This commit is contained in:
Sayak Paul
2025-09-24 15:09:07 +05:30
committed by GitHub
parent 457c7c1b8d
commit 9a188eadbe
3 changed files with 17 additions and 6 deletions

8
Makefile Normal file
View File

@ -0,0 +1,8 @@
.PHONY: style
export check_dirs := src examples tests
style:
black ${check_dirs}
isort ${check_dirs}
ruff check ${check_dirs} --fix

View File

@ -20,11 +20,11 @@ activation.gelu_fast(y, x)
print("Kernel successfully executed") print("Kernel successfully executed")
# Check results # Check results
expected = torch.tensor([ expected = torch.tensor(
[0.8408, 1.9551, 2.9961], [[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]],
[4.0000, 5.0000, 6.0000], device="cuda:0",
[7.0000, 8.0000, 9.0000] dtype=torch.float16,
], device='cuda:0', dtype=torch.float16) )
assert torch.allclose(y, expected) assert torch.allclose(y, expected)
print("Calculated values are exact") print("Calculated values are exact")

View File

@ -45,6 +45,9 @@ kernels = "kernels.cli:main"
[project.entry-points."egg_info.writers"] [project.entry-points."egg_info.writers"]
"kernels.lock" = "kernels.lockfile:write_egg_lockfile" "kernels.lock" = "kernels.lockfile:write_egg_lockfile"
[tool.isort]
profile = "black"
line_length = 119
[tool.ruff] [tool.ruff]
exclude = [ exclude = [
@ -71,4 +74,4 @@ line-length = 119
# Ignored rules: # Ignored rules:
# "E501" -> line length violation # "E501" -> line length violation
lint.ignore = ["E501"] lint.ignore = ["E501"]
lint.select = ["E", "F", "I", "W"] lint.select = ["E", "F", "W"]