Implement structural tag support in the Guidance backend (#17333)

Signed-off-by: Michal Moskal <michal@moskal.me>
This commit is contained in:
Michał Moskal
2025-04-28 19:21:32 -07:00
committed by GitHub
parent 506475de5f
commit 86d9fc29cb
2 changed files with 32 additions and 10 deletions

View File

@ -435,13 +435,10 @@ Given the previous instructions, what is the weather in New York City?
"""
# Change this once other backends support structural_tag
if guided_decoding_backend.startswith("xgrammar"):
outputs = llm.generate(prompts=prompt,
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
else:
outputs = []
outputs = llm.generate(prompts=prompt,
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
for output in outputs:
assert output is not None

View File

@ -173,7 +173,8 @@ def serialize_guidance_grammar(
disable_any_whitespace: bool = False,
no_additional_properties: bool = False,
) -> str:
if request_type == StructuredOutputOptions.JSON:
def _process_schema(grammar_spec: Union[str, dict[str, Any]], ) -> str:
if no_additional_properties:
grammar_spec = process_for_additional_properties(grammar_spec)
return llguidance.LLMatcher.grammar_from_json_schema(
@ -181,6 +182,9 @@ def serialize_guidance_grammar(
defaults={
"whitespace_flexible": not disable_any_whitespace,
})
if request_type == StructuredOutputOptions.JSON:
return _process_schema(grammar_spec)
elif request_type == StructuredOutputOptions.JSON_OBJECT:
return llguidance.LLMatcher.grammar_from_json_schema(
'{"type": "object"}',
@ -195,8 +199,29 @@ def serialize_guidance_grammar(
elif request_type == StructuredOutputOptions.CHOICE:
tp = "choice"
elif request_type == StructuredOutputOptions.STRUCTURAL_TAG:
raise ValueError("Structural tag is not supported "
"for guidance backend yet")
if isinstance(grammar_spec, str):
s_tag = json.loads(grammar_spec)
else:
s_tag = grammar_spec
triggers: list[str] = s_tag["triggers"]
tags: list[llguidance.StructTag] = []
for s in s_tag["structures"]:
begin: str = s["begin"]
trig = next((t for t in triggers if begin.startswith(t)), None)
if trig is None:
raise ValueError(
f"Trigger {begin} not found in triggers {triggers}")
tags.append(
llguidance.StructTag(
trigger=trig,
begin=s["begin"],
grammar=_process_schema(s["schema"]),
end=s["end"],
))
if not tags:
raise ValueError(
"No structural tags found in the grammar spec.")
return llguidance.StructTag.to_grammar(tags)
else:
logger.error("Validation should have already occurred. "
"Please file an issue.")