AI Tool Review - LM-Format-Enforcer for controlling LLM outputs

Posted on January 9, 2024 • Tags: llms ai ml hallucinations lm-format-enforcer python ai tool review

LM Format Enforcer is a Python library that guarantees that the output of an LLM will conform to an expected format.

Like most other LLM-guidance libraries, it works by masking the model's logits at each decoding step, so that any token which would violate the user-specified schema cannot be sampled.
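
To make the mechanism concrete, here is a toy sketch of logit masking. It is not LM Format Enforcer's actual implementation: the four-token vocabulary, the logit values, and the helper names mask_disallowed and softmax are all made up for illustration.

```python
import math
from typing import List, Set

def mask_disallowed(logits: List[float], allowed_ids: Set[int]) -> List[float]:
    """Set the logit of every token outside allowed_ids to -inf."""
    return [x if i in allowed_ids else -math.inf for i, x in enumerate(logits)]

def softmax(logits: List[float]) -> List[float]:
    """Convert logits to probabilities; -inf logits get probability 0."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical vocabulary and raw logits for the first generated token
vocab = ['{', '"', 'hello', '}']
logits = [1.0, 0.5, 3.0, 0.2]

# A JSON object must start with '{', so only token 0 is allowed here
allowed = {0}

masked = mask_disallowed(logits, allowed)
probs = softmax(masked)
print(dict(zip(vocab, probs)))
```

Without the mask, 'hello' would be the most likely token. With it, all of the probability mass lands on '{', so the sampler cannot pick anything that breaks the format.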

I previously reviewed a similar tool called Outlines.

Overall I still prefer Outlines, but LM Format Enforcer has more features, so I wanted to give it a try.


To install:

pip install lm-format-enforcer vllm

Use Cases

1. JSON Output with vLLM

Define your JSON schema in terms of Pydantic classes:

from pydantic import BaseModel

prompt: str = 'The patient is diabetic. Do they meet the criteria "abnormal A1C"?'

# Define our Pydantic class
class ExpectedJSONOutputFormat(BaseModel):
    rationale: str
    is_match: bool

Then run LM Format Enforcer with vLLM:

from lmformatenforcer.integrations.vllm import build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data
from typing import Union, List, Optional
from vllm import SamplingParams
from lmformatenforcer import JsonSchemaParser, CharacterLevelParser

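def vllm_with_character_level_parser(
    model,
    prompts: Union[str, List[str]],
    max_tokens: int = 4096,
    parser: Optional[CharacterLevelParser] = None,
) -> Union[str, List[str]]:
    # Build the tokenizer data that LM Format Enforcer needs for this model
    tokenizer_data = build_vllm_token_enforcer_tokenizer_data(model)
    sampling_params = SamplingParams()
    sampling_params.max_tokens = max_tokens
    if parser:
        # Constrain decoding to the tokens the parser allows
        logits_processor = build_vllm_logits_processor(tokenizer_data, parser)
        sampling_params.logits_processors = [logits_processor]
    results = model.generate(prompts, sampling_params=sampling_params)
    # Return one string for a single prompt, otherwise one string per prompt
    if isinstance(prompts, str):
        return results[0].outputs[0].text
    return [result.outputs[0].text for result in results]

# `model` is assumed to be a vllm.LLM instance created beforehand
schema = ExpectedJSONOutputFormat.schema()  # Pydantic v2 also offers model_json_schema()
# Pass the parser by keyword; the third positional argument is max_tokens
results: str = vllm_with_character_level_parser(model, prompt, parser=JsonSchemaParser(schema))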
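
Since the output is constrained to the schema, it should be possible to load it straight back into the Pydantic class. A minimal sketch, assuming pydantic is installed; the raw_output string below is a made-up example, not an actual model response.

```python
import json
from pydantic import BaseModel

class ExpectedJSONOutputFormat(BaseModel):
    rationale: str
    is_match: bool

# Hypothetical model output, invented for illustration
raw_output = '{"rationale": "Diabetes alone does not establish an abnormal A1C value.", "is_match": false}'

# Parse the JSON text, then validate it against the Pydantic class
parsed = ExpectedJSONOutputFormat(**json.loads(raw_output))
print(parsed.is_match, parsed.rationale)
```

Going through json.loads and the class constructor works on both Pydantic v1 and v2. If generation stops early because max_tokens is reached, the JSON may be truncated and this step will raise an error, so it is worth wrapping in a try/except in real use.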

Note: For batched generation, I found it faster to loop through the prompts and call .generate() on each one individually than to pass all the prompts to .generate() at once.
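
To check this on your own setup, a small timing harness like the one below can compare the two approaches. This is only a sketch: generate_looped, timed, and fake_generate are hypothetical helpers, and fake_generate is a stand-in that returns canned strings so the snippet runs without a GPU.

```python
import time
from typing import Callable, List, Tuple

def generate_looped(generate: Callable[[List[str]], List[str]], prompts: List[str]) -> List[str]:
    """Call generate once per prompt instead of once for the whole batch."""
    outputs: List[str] = []
    for prompt in prompts:
        outputs.extend(generate([prompt]))
    return outputs

def timed(fn: Callable[[], List[str]]) -> Tuple[List[str], float]:
    """Run fn and return its result together with the elapsed seconds."""
    start = time.perf_counter()
    result = fn()
    return result, time.perf_counter() - start

def fake_generate(prompts: List[str]) -> List[str]:
    """Stand-in for a real batched generation call (assumption, not vLLM)."""
    return [f"output for: {p}" for p in prompts]

prompts = ["prompt A", "prompt B", "prompt C"]
batched_out, batched_s = timed(lambda: fake_generate(prompts))
looped_out, looped_s = timed(lambda: generate_looped(fake_generate, prompts))
print(f"batched: {batched_s:.6f}s, looped: {looped_s:.6f}s")
```

To time real generation, replace fake_generate with a function that takes a list of prompts and returns a list of strings, for example a thin wrapper around the vLLM helper defined earlier.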



The other strengths and limitations are essentially the same as those described in my Outlines writeup.