mirror of
https://github.com/huggingface/kernels.git
synced 2025-10-20 21:10:02 +08:00
up (#157)
This commit is contained in:
8
Makefile
Normal file
8
Makefile
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
.PHONY: style
|
||||||
|
|
||||||
|
export check_dirs := src examples tests
|
||||||
|
|
||||||
|
style:
|
||||||
|
black ${check_dirs}
|
||||||
|
isort ${check_dirs}
|
||||||
|
ruff check ${check_dirs} --fix
|
@ -20,11 +20,11 @@ activation.gelu_fast(y, x)
|
|||||||
print("Kernel successfully executed")
|
print("Kernel successfully executed")
|
||||||
|
|
||||||
# Check results
|
# Check results
|
||||||
expected = torch.tensor([
|
expected = torch.tensor(
|
||||||
[0.8408, 1.9551, 2.9961],
|
[[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]],
|
||||||
[4.0000, 5.0000, 6.0000],
|
device="cuda:0",
|
||||||
[7.0000, 8.0000, 9.0000]
|
dtype=torch.float16,
|
||||||
], device='cuda:0', dtype=torch.float16)
|
)
|
||||||
assert torch.allclose(y, expected)
|
assert torch.allclose(y, expected)
|
||||||
|
|
||||||
print("Calculated values are exact")
|
print("Calculated values are exact")
|
||||||
|
@ -45,6 +45,9 @@ kernels = "kernels.cli:main"
|
|||||||
[project.entry-points."egg_info.writers"]
|
[project.entry-points."egg_info.writers"]
|
||||||
"kernels.lock" = "kernels.lockfile:write_egg_lockfile"
|
"kernels.lock" = "kernels.lockfile:write_egg_lockfile"
|
||||||
|
|
||||||
|
[tool.isort]
|
||||||
|
profile = "black"
|
||||||
|
line_length = 119
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
exclude = [
|
exclude = [
|
||||||
@ -71,4 +74,4 @@ line-length = 119
|
|||||||
# Ignored rules:
|
# Ignored rules:
|
||||||
# "E501" -> line length violation
|
# "E501" -> line length violation
|
||||||
lint.ignore = ["E501"]
|
lint.ignore = ["E501"]
|
||||||
lint.select = ["E", "F", "I", "W"]
|
lint.select = ["E", "F", "W"]
|
||||||
|
Reference in New Issue
Block a user