AI Tool Review - Outlines library for controlling LLM outputs

Posted on December 29, 2023 • Tags: llms ai ml hallucinations outlines python ai tool review

Outlines is a Python library that guarantees the output of an LLM will conform to a format you specify, such as a JSON schema, a regular expression, or a fixed set of choices.

I really like this library, since avoiding hallucinated or malformed outputs is a key challenge in many production settings.

Like most other LLM-guidance libraries, Outlines works by modifying the logits output by the model so that any token that would violate the user-specified schema has zero probability of being sampled.
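Here is a minimal sketch of that masking step in plain Python, independent of Outlines. The toy vocabulary, the scores, and the `allowed` set are made up for illustration; in a real library the allowed set is derived from the schema at every decoding step.

```python
import math
from typing import List, Set


def mask_logits(scores: List[float], allowed: Set[int]) -> List[float]:
    """Set the score of every token id not in `allowed` to -inf."""
    return [s if i in allowed else -math.inf for i, s in enumerate(scores)]


def softmax(scores: List[float]) -> List[float]:
    # Subtract the max for numerical stability; exp(-inf) is exactly 0.0
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical vocabulary and raw model scores: unconstrained, "hello" would win
vocab = ['{', '"', 'hello', 'true', '}']
scores = [0.5, 0.1, 3.0, 1.0, 0.2]

# Suppose the schema says the next token must open a JSON object
allowed = {0}
probs = softmax(mask_logits(scores, allowed))
next_token = vocab[max(range(len(probs)), key=probs.__getitem__)]
print(next_token)  # {
```

Every disallowed token ends up with probability exactly zero, so it cannot be sampled no matter how highly the model scored it.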

Outlines also works with vLLM, which is a super fast inference engine.

In this post, I’ll show you how to use Outlines to force a HuggingFace model to generate valid JSON and do multiple choice selection.

Setup

To install:

pip install outlines

Use Cases

1. JSON Output

You have two options here.

You can either define your JSON schema in terms of Pydantic classes:

from pydantic import BaseModel
import outlines

prompt: str = 'The patient is diabetic. Do they meet the criteria "abnormal A1C"?'

# Define our Pydantic class
class ExpectedJSONOutputFormat(BaseModel):
    rationale: str
    is_match: bool

# Load model via Outlines
model = outlines.models.transformers('gpt2', device='cuda')

# Enforce JSON output that conforms to `ExpectedJSONOutputFormat`
generator = outlines.generate.json(model, ExpectedJSONOutputFormat, max_tokens=500)

# Use the `generator` to sample an output from the model
result: ExpectedJSONOutputFormat = generator(prompt)

print(result)
# rationale='The patient says that they are diabetic' is_match=True

Or, you can simply supply your JSON schema as a string:

from typing import Any, Dict

schema: str = '''{
    "title": "ExpectedJSONOutputFormat",
    "type": "object",
    "properties": {
        "rationale": {
            "title": "Rationale",
            "type": "string"
        },
        "is_match": {
            "title": "Is Match",
            "type": "boolean"
        }
    },
    "required": ["is_match", "rationale"]
}'''

# Enforce JSON output that conforms to `schema` (reusing `model` and `prompt` from the previous snippet)
generator = outlines.generate.json(model, schema)

# Use the `generator` to sample an output from the model
result: Dict[str, Any] = generator(prompt)

print(result)
# {'rationale': 'The patient says that they are diabetic', 'is_match': True}
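To make "conforms to the schema" concrete, here is a small checker for a simplified version of the schema above (an object with string and boolean properties plus a `required` list). The helper `conforms` is hypothetical and is not part of Outlines; a real project would use a full JSON Schema validator such as the `jsonschema` package. Constrained generation guarantees outputs like `good` and makes outputs like `bad` impossible.

```python
import json
from typing import Any, Dict

# JSON Schema type names handled by this sketch, mapped to Python types
TYPE_MAP = {"string": str, "boolean": bool, "object": dict}


def conforms(instance: Any, schema: Dict[str, Any]) -> bool:
    """Hypothetical checker for a tiny subset of JSON Schema (not Outlines code)."""
    if not isinstance(instance, TYPE_MAP[schema["type"]]):
        return False
    if schema["type"] != "object":
        return True
    # Every required key must be present
    if any(key not in instance for key in schema.get("required", [])):
        return False
    # Every declared property that is present must match its own sub-schema
    return all(
        conforms(instance[name], sub)
        for name, sub in schema.get("properties", {}).items()
        if name in instance
    )


schema = {
    "title": "ExpectedJSONOutputFormat",
    "type": "object",
    "properties": {
        "rationale": {"title": "Rationale", "type": "string"},
        "is_match": {"title": "Is Match", "type": "boolean"},
    },
    "required": ["is_match", "rationale"],
}

good = json.loads('{"rationale": "The patient says that they are diabetic", "is_match": true}')
bad = json.loads('{"rationale": "unsure", "is_match": "maybe"}')
print(conforms(good, schema), conforms(bad, schema))  # True False
```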

2. Multiple Choice

Forcing the model to choose from a set of options is super simple:

from typing import List
import outlines

prompt: str = 'Are you happy or sad?'
options: List[str] = ["Positive", "Negative"]

# Load model via Outlines
model = outlines.models.transformers('gpt2', device='cuda')

# Generate an answer from `options`
answer = outlines.generate.choice(model, options)(prompt)

print(answer)
# 'Positive'
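Conceptually, picking from a fixed list is just a regex constraint: the output must fully match one alternative of `(Positive|Negative)`. The sketch below builds such a pattern with Python's `re` module. The helper `choice_regex` is hypothetical; my understanding is that Outlines' `choice` works along these lines, but this is not its actual code.

```python
import re
from typing import List


def choice_regex(options: List[str]) -> str:
    """Build a regex that matches exactly one of `options`, escaping special characters."""
    return "(" + "|".join(re.escape(option) for option in options) + ")"


options = ["Positive", "Negative"]
pattern = choice_regex(options)
print(pattern)  # (Positive|Negative)

# Only a complete option is a valid final output
for text in ["Positive", "Negative", "Pos", "Positive!"]:
    print(text, bool(re.fullmatch(pattern, text)))
```

Escaping matters for options that contain regex metacharacters: without `re.escape`, an option like `a.b` would also accept `axb`.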

3. Others

With Outlines, you can force the output of your LLM to conform to any regex or context-free grammar.

There are tons of other examples on the official docs!

vLLM integration

Outlines supports vLLM, as described in the vLLM tutorial in the official Outlines docs.

For some reason, however, I couldn't get it working with a multi-GPU setup (e.g. setting tensor_parallel_size > 1 in vLLM).

In order to get it working, I had to explicitly set use_beam_search=False and n=1 as arguments to the SamplingParams constructor.

Here is a code snippet for using vLLM with Outlines within your code (i.e. not using the API server):

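outlines_vllm.py (any name except outlines.py works; a file called outlines.py would shadow the installed outlines package and break the imports below)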

import json
import math
from collections import defaultdict
from typing import DefaultDict, List

import torch
from outlines.fsm.fsm import RegexFSM
from outlines.fsm.json_schema import build_regex_from_object

class RegexLogitsProcessor:
    """vLLM logits processor that only allows tokens consistent with a regex."""

    def __init__(self, regex_string, llm):
        tokenizer = self.adapt_tokenizer(llm.get_tokenizer())
        # Compile the regex into a finite-state machine over the tokenizer's vocabulary
        self.fsm = RegexFSM(regex_string, tokenizer)
        # One FSM state per sequence ID; 0 is the initial state
        self.fsm_state: DefaultDict[int, int] = defaultdict(int)

    def __call__(
        self, seq_id: int, input_ids: List[int], scores: torch.Tensor
    ) -> torch.Tensor:
        """Use the FSM to bias the logits before sampling the next token."""
        if len(input_ids) == 0:  # Initialize the fsm states
            self.fsm_state: DefaultDict[int, int] = defaultdict(int)
        else:
            # Advance this sequence's FSM state using the token sampled at the last step
            last_token = input_ids[-1]
            self.fsm_state[seq_id] = self.fsm.next_state(
                self.fsm_state[seq_id], last_token
            )
        # -inf for every token the FSM forbids in the current state, 0 for allowed ones
        allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
        mask = torch.full((scores.shape[-1],), -math.inf, device=scores.device)
        mask[allowed_tokens] = 0
        biased_scores = scores + mask
        return biased_scores

    def adapt_tokenizer(self, tokenizer):
        """Patch a Hugging Face tokenizer with the extra attributes Outlines' FSM uses."""
        tokenizer.vocabulary = tokenizer.get_vocab()
        tokenizer.special_tokens = set(tokenizer.all_special_tokens)

        def convert_token_to_string(token: str) -> str:
            from transformers.file_utils import SPIECE_UNDERLINE

            string = tokenizer.convert_tokens_to_string([token])
            # A hack to handle missing spaces in HF's Llama tokenizers
            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
                return " " + string
            return string

        tokenizer.convert_token_to_string = convert_token_to_string
        return tokenizer


class JSONLogitsProcessor(RegexLogitsProcessor):
    """Constrain generation to JSON matching a schema (given as a dict or a JSON string)."""

    def __init__(self, schema, llm):
        if isinstance(schema, dict):
            schema = json.dumps(schema)
        # Convert the JSON schema into an equivalent regex, then reuse the regex processor
        regex_string = build_regex_from_object(schema)
        super().__init__(regex_string, llm)

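# NOTE: vLLM (at the time of writing) calls each logits processor as
# processor(token_ids, logits), without the sequence ID. Copy this function into
# vllm.model_executor.layers.sampler._apply_logits_processors so that each processor
# also receives `seq_id`, which the processors above need to track per-sequence FSM state.
def _patched_apply_logits_processors(
    logits,
    sampling_metadata,
):
    logits_row_idx = 0
    found_logits_processors = False
    for seq_ids, sampling_params in sampling_metadata.seq_groups:
        logits_processors = sampling_params.logits_processors
        if logits_processors:
            found_logits_processors = True
            for seq_id in seq_ids:
                logits_row = logits[logits_row_idx]
                token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
                for logits_processor in logits_processors:
                    # Pass seq_id as the first argument
                    logits_row = logits_processor(seq_id, token_ids, logits_row)
                logits[logits_row_idx] = logits_row
                logits_row_idx += 1
        else:
            # Skip the rows of sequence groups that have no processors
            logits_row_idx += len(seq_ids)
    if found_logits_processors:
        # Sanity check: every row of `logits` was visited exactly once
        assert logits_row_idx == logits.shape[0]
    return logits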

script.py

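import json
from typing import Any, Dict, List

from pydantic import BaseModel
from vllm import LLM, RequestOutput, SamplingParams

# Assumes the module above is saved as outlines_vllm.py (any name other than
# outlines.py works; that name would shadow the installed outlines package)
from outlines_vllm import JSONLogitsProcessor

# Same Pydantic class as in the JSON example
class ExpectedJSONOutputFormat(BaseModel):
    rationale: str
    is_match: bool

prompts: List[str] = [ "The sky is", "Hello, my name" ]
n_gpus: int = 4 # number of GPUs to use

# Generate responses with vLLM
model = LLM(model='gpt2', tensor_parallel_size=n_gpus)
sampling_params = SamplingParams(
    max_tokens=4096,
    logits_processors=[
        JSONLogitsProcessor(ExpectedJSONOutputFormat.schema(), model)
    ],
    # IMPORTANT -- keep the below to get multi-GPU working:
    use_beam_search=False,
    n=1,
)
raw_responses: List[RequestOutput] = model.generate(prompts, sampling_params)

# Parse responses
results = []
for raw_response in raw_responses:
    response: Dict[str, Any] = json.loads(raw_response.outputs[0].text)
    result = ExpectedJSONOutputFormat(**response)
    # Get usage stats
    prompt_tokens: int = len(raw_response.prompt_token_ids)
    completion_tokens: int = len(raw_response.outputs[0].token_ids)
    stats: Dict[str, int] = {"prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens}
    results.append((result, stats))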

Takeaways

Strengths

  1. Easy to use. I like that the interface is super simple and straightforward. It readily integrates into any stack that uses HuggingFace models.
  2. Guaranteed output. I also really like that Outlines guarantees the output of your model will conform to your schema, versus other libraries that rely on retrying / parsing / hoping for the best.
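  3. Speed. Outlines is supposed to be faster than alternatives such as Microsoft's Guidance or LMQL because it compiles the constraint into a finite-state machine and loops over the vocabulary only once at the start, precomputing which tokens are allowed in each state, rather than at each step of generation.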
  4. Flexibility. Outlines supports arbitrary Regex-specified rulesets, so you can specify a very wide range of output formats.
  5. vLLM integration. It has an officially supported integration with vLLM, a super fast inference engine.
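The speed argument comes from precomputation: for each state of the finite-state machine, the set of vocabulary tokens that keep the output valid is computed once, so each decoding step is just a lookup. Below is a toy sketch of that idea in plain Python. The DFA is hand-written for the regex `[0-9]+(\.[0-9]+)?` and the tiny vocabulary is made up; Outlines builds its FSM automatically from the regex and works with real tokenizer vocabularies, so treat this as an illustration of the principle rather than its implementation.

```python
from typing import Dict, List, Optional, Set

# Toy DFA, hand-written for the regex [0-9]+(\.[0-9]+)?  (hypothetical example)
# States: 0 = start, 1 = in integer part, 2 = just saw ".", 3 = in fraction part
ACCEPTING: Set[int] = {1, 3}
DIGITS = "0123456789"
EOS = "<eos>"  # end-of-sequence token, legal only in an accepting state


def step(state: int, char: str) -> Optional[int]:
    """Return the next DFA state, or None if `char` is not allowed here."""
    if char in DIGITS:
        return {0: 1, 1: 1, 2: 3, 3: 3}[state]
    if char == "." and state == 1:
        return 2
    return None


def walk(state: int, token: str) -> Optional[int]:
    """Feed every character of a (possibly multi-character) token through the DFA."""
    for char in token:
        next_state = step(state, char)
        if next_state is None:
            return None
        state = next_state
    return state


def build_index(vocab: List[str], n_states: int = 4) -> Dict[int, Dict[int, int]]:
    """Scan the vocabulary once: state -> {allowed token id: next state}."""
    index: Dict[int, Dict[int, int]] = {}
    for state in range(n_states):
        allowed: Dict[int, int] = {}
        for token_id, token in enumerate(vocab):
            if token == EOS:
                if state in ACCEPTING:
                    allowed[token_id] = state
                continue
            next_state = walk(state, token)
            if next_state is not None:
                allowed[token_id] = next_state
        index[state] = allowed
    return index


# Made-up vocabulary with multi-character tokens, as real tokenizers have
vocab: List[str] = ["1", "23", ".", ".5", "4.", "a", "7.8", EOS]
index = build_index(vocab)

# At generation time, each step is a dictionary lookup rather than a vocabulary scan
for state in range(4):
    print(state, [vocab[t] for t in index[state]])
# 0 ['1', '23', '4.', '7.8']
# 1 ['1', '23', '.', '.5', '4.', '7.8', '<eos>']
# 2 ['1', '23']
# 3 ['1', '23', '<eos>']
```

Multi-character tokens such as `"7.8"` are checked as a whole, which is why the index is built per token and per state rather than per character.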

Limitations

  1. Limited support for closed models like OpenAI's. Proprietary models from OpenAI, Anthropic, Google (Gemini), and Cohere are not fully compatible, since their APIs don't expose the model's raw logits.
    1. However, Outlines does support multiple choice querying of OpenAI models like GPT-4 via the model.generate_chat() function. My understanding is that it queries OpenAI one token at a time and uses the API's logit_bias parameter to give a bias of 100 to the tokens that can still lead to a valid answer choice, so that only those tokens can be sampled. It then appends the generated token to the prompt and repeats until a complete valid answer has been generated.
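My reading of that loop can be sketched as follows. Everything here is a stand-in: the toy vocabulary, the token ids, and `fake_next_token` (which plays the role of one API call returning a single token) are hypothetical, and no real OpenAI client is used. The only idea carried over is the +100 bias on tokens that can still lead to a valid answer.

```python
from typing import Dict, List, Set

# Hypothetical toy tokenizer: token id -> text
VOCAB: Dict[int, str] = {0: "Pos", 1: "itive", 2: "Neg", 3: "ative", 4: "Hello"}
# Each answer choice, pre-tokenized into token ids
CHOICES: Dict[str, List[int]] = {"Positive": [0, 1], "Negative": [2, 3]}


def fake_next_token(base_scores: Dict[int, float], logit_bias: Dict[int, float]) -> int:
    """Stand-in for one API call that returns a single token (greedy decoding)."""
    return max(base_scores, key=lambda t: base_scores[t] + logit_bias.get(t, 0.0))


def choose(base_scores: Dict[int, float]) -> str:
    generated: List[int] = []
    while True:
        # Token ids that extend `generated` toward at least one answer choice
        allowed: Set[int] = {
            ids[len(generated)]
            for ids in CHOICES.values()
            if len(ids) > len(generated) and ids[: len(generated)] == generated
        }
        # A +100 bias on the allowed tokens should make them the only ones picked
        token = fake_next_token(base_scores, {t: 100.0 for t in allowed})
        if token not in allowed:
            raise RuntimeError("model picked a token outside the allowed set")
        generated.append(token)  # i.e. add the token back to the prompt
        if generated in CHOICES.values():
            return "".join(VOCAB[t] for t in generated)


# Unbiased, the model would prefer "Hello" (score 5.0), but the bias wins
print(choose({0: 1.0, 1: 0.5, 2: 2.0, 3: 0.1, 4: 5.0}))  # Negative
```

Each loop iteration corresponds to one API round trip, which is why this approach is much slower than masking logits locally.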

Alternatives

There are many other libraries that support logit-based control over LLM outputs.

Below, I’ve copied the table from the LM-Format-Enforcer README.

I’ve only tried LM Format Enforcer from this list, and even with its own vLLM integration it was significantly slower than Outlines.

So take the rest of this table with a grain of salt: I haven't verified it, and (unsurprisingly) LM-Format-Enforcer comes out on top. Still, it should give a sense of the other libraries' capabilities:

Capability LM Format Enforcer Guidance Jsonformer Outlines
Regular Expressions
JSON Schema 🟡 (Partial conversion is possible)
Batched Generation
Beam Search
Integrates into existing pipelines
Optional JSON Fields
LLM Controls JSON field ordering and whitespace
