[feat] add an uploading utility (#138)

* add an uploading utility.

* format

* remove stale files.

* black format

* sorted imports.

* up

* up

* add a test

* propagate.

* remove duplicate imports.

* Apply suggestions from code review

Co-authored-by: Daniël de Kok <me@danieldk.eu>

* up

* up

* up

* command to format all files at once would be nice.

* up

* up

* up

* Use token for upload test

* assign env better.

* docs

* polish

* up

* xfail the test for now.

---------

Co-authored-by: Daniël de Kok <me@danieldk.eu>
This commit is contained in:
Sayak Paul
2025-09-16 12:26:54 +05:30
committed by GitHub
parent d383fdd4b4
commit d6b51eefb7
5 changed files with 160 additions and 1 deletions

View File

@ -0,0 +1,92 @@
import logging
import os
import re
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import List
import pytest
from huggingface_hub import model_info
from kernels.cli import upload_kernels
REPO_ID = "kernels-test/kernels-upload-test"
PY_CONTENT = """\
#!/usr/bin/env python3
def main():
print("Hello from torch-universal!")
if __name__ == "__main__":
main()
"""
@dataclass
class UploadArgs:
kernel_dir: None
repo_id: None
private: False
def next_filename(path: Path) -> Path:
"""
Given a path like foo_2050.py, return foo_2051.py.
"""
m = re.match(r"^(.*?)(\d+)(\.py)$", path.name)
if not m:
raise ValueError(
f"Filename {path.name!r} does not match pattern <prefix>_<number>.py"
)
prefix, number, suffix = m.groups()
new_number = str(int(number) + 1).zfill(len(number))
return path.with_name(f"{prefix}{new_number}{suffix}")
def get_filename_to_change(repo_filenames):
for f in repo_filenames:
if "foo" in f and f.endswith(".py"):
filename_to_change = os.path.basename(f)
break
assert filename_to_change
return filename_to_change
def get_filenames_from_a_repo(repo_id: str) -> List[str]:
try:
repo_info = model_info(repo_id=repo_id, files_metadata=True)
repo_siblings = repo_info.siblings
if repo_siblings is not None:
return [f.rfilename for f in repo_siblings]
else:
raise ValueError("No repo siblings found.")
except Exception as e:
logging.error(f"Error connecting to the Hub: {e}.")
@pytest.mark.xfail(
condition=os.environ.get("GITHUB_ACTIONS") == "true",
reason="There is something weird when writing to the Hub from a GitHub CI.",
strict=True,
)
def test_kernel_upload_deletes_as_expected():
repo_filenames = get_filenames_from_a_repo(REPO_ID)
filename_to_change = get_filename_to_change(repo_filenames)
with tempfile.TemporaryDirectory() as tmpdir:
path = f"{tmpdir}/build/torch-universal/upload_test"
build_dir = Path(path)
build_dir.mkdir(parents=True, exist_ok=True)
changed_filename = next_filename(Path(filename_to_change))
script_path = build_dir / changed_filename
script_path.write_text(PY_CONTENT)
upload_kernels(UploadArgs(tmpdir, REPO_ID, False))
repo_filenames = get_filenames_from_a_repo(REPO_ID)
assert any(str(changed_filename) in k for k in repo_filenames), f"{repo_filenames=}"
assert not any(
str(filename_to_change) in k for k in repo_filenames
), f"{repo_filenames=}"