[Misc] Validate grammar and fail early (#11119)

This commit is contained in:
Cody Yu
2024-12-12 10:34:26 -08:00
committed by GitHub
parent 5d712571af
commit 2c97eca1ff
2 changed files with 26 additions and 18 deletions

View File

@ -131,22 +131,25 @@ class GrammarConfig:
max_threads: int = 8) -> GrammarConfig:
tokenizer_hash = hash(tokenizer)
# Only get tokenizer data if not already cached
if tokenizer_hash in TokenizerDataCache._cache:
encoded_vocab = None
stop_token_ids = None
backend_str = None
else:
tokenizer_data = TokenizerDataCache.get_tokenizer_data(tokenizer)
encoded_vocab = tokenizer_data.encoded_vocab
stop_token_ids = tokenizer_data.stop_token_ids
backend_str = tokenizer_data.backend_str
tokenizer_data = TokenizerDataCache.get_tokenizer_data(tokenizer)
encoded_vocab = tokenizer_data.encoded_vocab
stop_token_ids = tokenizer_data.stop_token_ids
backend_str = tokenizer_data.backend_str
if guided_params.json:
if not isinstance(guided_params.json, str):
json_str = json.dumps(guided_params.json)
else:
json_str = guided_params.json
# Validate the schema and raise ValueError here if it is invalid.
# This is to avoid exceptions in model execution, which will crash
# the engine worker process.
try:
xgr.Grammar.from_json_schema(json_str)
except RuntimeError as err:
raise ValueError(str(err)) from err
return cls(json_str=json_str,
vocab_size=model_config.hf_text_config.vocab_size,
encoded_vocab=encoded_vocab,
@ -167,6 +170,15 @@ class GrammarConfig:
f"Conversion error: {str(e)}") from e
else:
grammar_str = guided_params.grammar
# Validate the grammar and raise ValueError here if it is invalid.
# This is to avoid exceptions in model execution, which will crash
# the engine worker process.
try:
xgr.Grammar.from_ebnf(grammar_str)
except RuntimeError as err:
raise ValueError(str(err)) from err
return cls(grammar_str=grammar_str,
vocab_size=model_config.hf_text_config.vocab_size,
encoded_vocab=encoded_vocab,

View File

@ -26,15 +26,11 @@ def grammar_is_likely_lark(grammar_str: str) -> bool:
if not line:
continue
# Look for Lark-style rule definitions
if ':' in line and '::=' not in line:
return True
# Look for GBNF rule definition
if '::=' in line:
return False
# Look for Lark-specific features
if any(pattern in line for pattern in ['?start:', '|', '~']):
return True
return False
return True
def convert_lark_to_gbnf(grammar_str: str) -> str: