mirror of
https://github.com/huggingface/kernels.git
synced 2025-10-20 21:10:02 +08:00
Improve errors for layer validation (#145)
* Improve errors for layer validation Include the repo and layer name as well as the name of the class that is being compared to (when applicable). * Remove upload xfail * Only enable tests that require a token with `--token`
This commit is contained in:
@ -4,3 +4,4 @@ markers =
|
||||
rocm_only: marks tests that should only run on hosts with ROCm GPUs
|
||||
darwin_only: marks tests that should only run on macOS
|
||||
xpu_only: marks tests that should only run on hosts with Intel XPUs
|
||||
token: enable tests that require a write token
|
||||
|
@ -316,7 +316,7 @@ class LayerRepository:
|
||||
return hash((self.layer_name, self._repo_id, self._revision, self._version))
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"`{self._repo_id}` (revision: {self._resolve_revision()}) for layer `{self.layer_name}`"
|
||||
return f"`{self._repo_id}` (revision: {self._resolve_revision()}), layer `{self.layer_name}`"
|
||||
|
||||
|
||||
class LocalLayerRepository:
|
||||
@ -372,7 +372,7 @@ class LocalLayerRepository:
|
||||
return hash((self.layer_name, self._repo_path, self._package_name))
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"`{self._repo_path}` (package: {self._package_name}) for layer `{self.layer_name}`"
|
||||
return f"`{self._repo_path}` (package: {self._package_name}), layer `{self.layer_name}`"
|
||||
|
||||
|
||||
class LockedLayerRepository:
|
||||
@ -427,7 +427,7 @@ class LockedLayerRepository:
|
||||
return hash((self.layer_name, self._repo_id))
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"`{self._repo_id}` (revision: {self._resolve_revision()}) for layer `{self.layer_name}`"
|
||||
return f"`{self._repo_id}` (revision: {self._resolve_revision()}), layer `{self.layer_name}`"
|
||||
|
||||
|
||||
_CACHED_LAYER: Dict[LayerRepositoryProtocol, Type["nn.Module"]] = {}
|
||||
@ -1020,7 +1020,7 @@ def _get_kernel_layer(repo: LayerRepositoryProtocol) -> Type["nn.Module"]:
|
||||
return layer
|
||||
|
||||
|
||||
def _validate_layer(*, check_cls, cls):
|
||||
def _validate_layer(*, check_cls, cls, repo: LayerRepositoryProtocol):
|
||||
import torch.nn as nn
|
||||
|
||||
# The layer must have at least have the following properties: (1) it
|
||||
@ -1029,12 +1029,12 @@ def _validate_layer(*, check_cls, cls):
|
||||
# methods.
|
||||
|
||||
if not issubclass(cls, nn.Module):
|
||||
raise TypeError(f"Layer `{cls}` is not a Torch layer.")
|
||||
raise TypeError(f"Layer `{cls.__name__}` is not a Torch layer.")
|
||||
|
||||
# We verify statelessness by checking that the does not have its own
|
||||
# constructor (since the constructor could add member variables)...
|
||||
if cls.__init__ is not nn.Module.__init__:
|
||||
raise TypeError("Layer must not override nn.Module constructor.")
|
||||
raise TypeError(f"{repo} must not override nn.Module constructor.")
|
||||
|
||||
# ... or predefined member variables.
|
||||
torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
|
||||
@ -1042,7 +1042,9 @@ def _validate_layer(*, check_cls, cls):
|
||||
difference = cls_members - torch_module_members
|
||||
# verify if : difference ⊄ {"can_torch_compile", "has_backward"}
|
||||
if not difference <= {"can_torch_compile", "has_backward"}:
|
||||
raise TypeError("Layer must not contain additional members.")
|
||||
raise TypeError(
|
||||
f"{repo} must not contain additional members compared to `{check_cls.__name__}`."
|
||||
)
|
||||
|
||||
# Check whether the forward signatures are similar.
|
||||
params = inspect.signature(cls.forward).parameters
|
||||
@ -1050,13 +1052,13 @@ def _validate_layer(*, check_cls, cls):
|
||||
|
||||
if len(params) != len(ref_params):
|
||||
raise TypeError(
|
||||
"Forward signature does not match: different number of arguments."
|
||||
f"Forward signature of {repo} does not match `{check_cls.__name__}`: different number of arguments."
|
||||
)
|
||||
|
||||
for param, ref_param in zip(params.values(), ref_params.values()):
|
||||
if param.kind != ref_param.kind:
|
||||
raise TypeError(
|
||||
f"Forward signature does not match: different kind of arguments ({param} ({param.kind}) and {ref_param} ({ref_param.kind})"
|
||||
f"Forward signature of {repo} does not match `{check_cls.__name__}`: different kind of arguments ({param} ({param.kind}) and {ref_param} ({ref_param.kind})"
|
||||
)
|
||||
|
||||
|
||||
@ -1173,7 +1175,7 @@ def _get_layer_memoize(
|
||||
return layer
|
||||
|
||||
layer = _get_kernel_layer(repo)
|
||||
_validate_layer(check_cls=module_class, cls=layer)
|
||||
_validate_layer(check_cls=module_class, cls=layer, repo=repo)
|
||||
_CACHED_LAYER[repo] = layer
|
||||
|
||||
return layer
|
||||
|
@ -20,6 +20,14 @@ has_xpu = (
|
||||
)
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption(
|
||||
"--token",
|
||||
action="store_true",
|
||||
help="run tests that require a token with write permissions",
|
||||
)
|
||||
|
||||
|
||||
def pytest_runtest_setup(item):
|
||||
if "cuda_only" in item.keywords and not has_cuda:
|
||||
pytest.skip("skipping CUDA-only test on host without CUDA")
|
||||
@ -29,3 +37,5 @@ def pytest_runtest_setup(item):
|
||||
pytest.skip("skipping macOS-only test on non-macOS platform")
|
||||
if "xpu_only" in item.keywords and not has_xpu:
|
||||
pytest.skip("skipping XPU-only test on host without XPU")
|
||||
if "token" in item.keywords and not item.config.getoption("--token"):
|
||||
pytest.skip("need --token option to run this test")
|
||||
|
@ -67,11 +67,7 @@ def get_filenames_from_a_repo(repo_id: str) -> List[str]:
|
||||
logging.error(f"Error connecting to the Hub: {e}.")
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
condition=os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="There is something weird when writing to the Hub from a GitHub CI.",
|
||||
strict=True,
|
||||
)
|
||||
@pytest.mark.token
|
||||
def test_kernel_upload_deletes_as_expected():
|
||||
repo_filenames = get_filenames_from_a_repo(REPO_ID)
|
||||
filename_to_change = get_filename_to_change(repo_filenames)
|
||||
|
@ -480,26 +480,43 @@ def test_validate_kernel_layer():
|
||||
super().__init__(*args, **kwargs)
|
||||
self.foo = 42
|
||||
|
||||
with pytest.raises(TypeError, match="not override"):
|
||||
_validate_layer(cls=BadLayer, check_cls=SiluAndMul)
|
||||
def stub_repo(layer):
|
||||
return LayerRepository(
|
||||
repo_id="kernels-test/nonexisting", layer_name=layer.__name__
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
TypeError,
|
||||
match="`kernels-test/nonexisting`.*layer `BadLayer` must not override",
|
||||
):
|
||||
_validate_layer(cls=BadLayer, check_cls=SiluAndMul, repo=stub_repo(BadLayer))
|
||||
|
||||
class BadLayer2(nn.Module):
|
||||
foo: int = 42
|
||||
|
||||
with pytest.raises(TypeError, match="not contain additional members"):
|
||||
_validate_layer(cls=BadLayer2, check_cls=SiluAndMul)
|
||||
with pytest.raises(
|
||||
TypeError,
|
||||
match="`kernels-test/nonexisting`.*layer `BadLayer2` must not contain.*SiluAndMul",
|
||||
):
|
||||
_validate_layer(cls=BadLayer2, check_cls=SiluAndMul, repo=stub_repo(BadLayer2))
|
||||
|
||||
class BadLayer3(nn.Module):
|
||||
def forward(self, x: torch.Tensor, foo: int) -> torch.Tensor: ...
|
||||
|
||||
with pytest.raises(TypeError, match="different number of arguments"):
|
||||
_validate_layer(cls=BadLayer3, check_cls=SiluAndMul)
|
||||
with pytest.raises(
|
||||
TypeError,
|
||||
match="Forward.*`kernels-test/nonexisting`.*layer `BadLayer3` does not match `SiluAndMul`: different number of arguments",
|
||||
):
|
||||
_validate_layer(cls=BadLayer3, check_cls=SiluAndMul, repo=stub_repo(BadLayer3))
|
||||
|
||||
class BadLayer4(nn.Module):
|
||||
def forward(self, *, x: torch.Tensor) -> torch.Tensor: ...
|
||||
|
||||
with pytest.raises(TypeError, match="different kind of arguments"):
|
||||
_validate_layer(cls=BadLayer4, check_cls=SiluAndMul)
|
||||
with pytest.raises(
|
||||
TypeError,
|
||||
match="Forward.*`kernels-test/nonexisting`.*layer `BadLayer4` does not match `SiluAndMul`: different kind of arguments",
|
||||
):
|
||||
_validate_layer(cls=BadLayer4, check_cls=SiluAndMul, repo=stub_repo(BadLayer4))
|
||||
|
||||
|
||||
@pytest.mark.cuda_only
|
||||
|
Reference in New Issue
Block a user