
AI Tool Review - LM-Format-Enforcer for controlling LLM outputs

Posted on January 9, 2024 • Tags: llms ai ml hallucinations lm-format-enforcer python ai tool review

LM Format Enforcer is a Python library that guarantees an LLM's output will conform to a user-specified format.

Like most other LLM-guidance libraries, it works by biasing the model's output logits so that tokens that would violate the user-specified schema are prevented from being sampled.
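
To make that mechanism concrete, here is a minimal sketch of the idea (not LM Format Enforcer's actual implementation): a logits processor that sets the logit of every disallowed token to negative infinity, so it gets zero probability after the softmax.

import math
from typing import List, Set

def mask_disallowed_tokens(logits: List[float], allowed_token_ids: Set[int]) -> List[float]:
    # Tokens outside the allowed set get a logit of -inf, so their
    # post-softmax probability is zero and they can never be sampled.
    return [
        logit if token_id in allowed_token_ids else -math.inf
        for token_id, logit in enumerate(logits)
    ]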

I previously reviewed a similar tool called Outlines.

LM Format Enforcer has more features than Outlines, so I wanted to give it a try, though overall I still prefer Outlines.

Setup

To install:

pip install lm-format-enforcer vllm

Use Cases

1. JSON Output with vLLM

Define your JSON schema in terms of Pydantic classes:

from pydantic import BaseModel

prompt: str = 'The patient is diabetic. Do they meet the criteria "abnormal A1C"?'

# Define our Pydantic class
class ExpectedJSONOutputFormat(BaseModel):
    rationale: str
    is_match: bool
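
For reference, calling .schema() on this class (as we do below) produces a standard JSON schema dict, roughly:

ExpectedJSONOutputFormat.schema()
# {'title': 'ExpectedJSONOutputFormat',
#  'type': 'object',
#  'properties': {'rationale': {'title': 'Rationale', 'type': 'string'},
#                 'is_match': {'title': 'Is Match', 'type': 'boolean'}},
#  'required': ['rationale', 'is_match']}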

Then run LM Format Enforcer with vLLM:

from typing import List, Optional, Union

from vllm import LLM, SamplingParams

from lmformatenforcer import CharacterLevelParser, JsonSchemaParser
from lmformatenforcer.integrations.vllm import build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data

def vllm_with_character_level_parser(model: LLM, prompts: Union[str, List[str]], max_tokens: int = 4096, parser: Optional[CharacterLevelParser] = None) -> Union[str, List[str]]:
    tokenizer_data = build_vllm_token_enforcer_tokenizer_data(model)
    sampling_params = SamplingParams(max_tokens=max_tokens)
    if parser:
        # The logits processor masks out any token that would take the
        # output out of compliance with the parser's schema.
        logits_processor = build_vllm_logits_processor(tokenizer_data, parser)
        sampling_params.logits_processors = [logits_processor]
    results = model.generate(prompts, sampling_params=sampling_params)
    if isinstance(prompts, str):
        return results[0].outputs[0].text
    else:
        return [result.outputs[0].text for result in results]

model = LLM(model='mistralai/Mistral-7B-Instruct-v0.2')  # any vLLM-supported model works here
schema = ExpectedJSONOutputFormat.schema()
results: str = vllm_with_character_level_parser(model, prompt, parser=JsonSchemaParser(schema))
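
Since the output is guaranteed to be schema-conformant JSON (barring truncation at the max_tokens limit), it can be parsed straight back into the Pydantic model:

parsed = ExpectedJSONOutputFormat.parse_raw(results)  # model_validate_json() on Pydantic v2
print(parsed.rationale, parsed.is_match)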

Note: For batched generation, I found it faster to loop through the prompts and call .generate() on each one individually than to pass all the prompts to .generate() at once.
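
Concretely, that means something like the following sketch (assuming prompts is a list of strings) rather than a single call with the full prompt list:

all_results: List[str] = [
    vllm_with_character_level_parser(model, p, parser=JsonSchemaParser(schema))
    for p in prompts
]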

Takeaways

Strengths

  1. vLLM Support: I can confirm that LM Format Enforcer supports vLLM, though I found it to be a bit slower than Outlines.

The other strengths and limitations are essentially the same as those in my Outlines writeup.
