
Beam Search Decoding

Role: Machine Learning Engineer


Problem Overview

Implement beam search, a decoding algorithm widely used in sequence-to-sequence models (machine translation, text generation, speech recognition). Given a starting token sequence, a function that returns the next-token probability distribution, and search parameters, produce the top-scoring output sequences.

The interviewer provides the function signature. You must complete the implementation, and your code must pass unit tests.

Given Signature:

python
from typing import List, Callable

def beam_search(
    input_seq: List[int],
    next_token_fn: Callable[[List[int]], List[float]],
    max_token: int,
    beam_size: int,
    stop_word_id: int
) -> List[List[int]]:
    pass

Parameters:

  • input_seq: Initial token sequence (the prompt). The generated tokens are appended after this prefix.

  • next_token_fn: Takes a token sequence and returns a probability distribution over the vocabulary. next_token_fn(seq)[i] is the probability of token i being the next token.

  • max_token: Maximum number of new tokens to generate.

  • beam_size: Number of candidate sequences (beams) to keep at each step.

  • stop_word_id: Token ID that signals end of generation. A sequence is "completed" once it generates this token.

Returns:

  • List of sequences (each including the input_seq prefix), sorted by cumulative log-probability from highest to lowest. Completed sequences (ending with stop_word_id) are preferred; if none exist, return the best active beams.

Example:

python
# Vocabulary: {0: 'hello', 1: 'world', 2: '<EOS>'}
def simple_next_token(seq):
    return [0.2, 0.5, 0.3]

result = beam_search(
    input_seq=[0],
    next_token_fn=simple_next_token,
    max_token=2,
    beam_size=2,
    stop_word_id=2
)
# result: [[0, 2], [0, 1, 2], [0, 0, 2]]
#   [0, 2]:       score = log(0.3) ≈ -1.20
#   [0, 1, 2]:    score = log(0.5) + log(0.3) ≈ -1.90
#   [0, 0, 2]:    score = log(0.2) + log(0.3) ≈ -2.81

Key Requirement: Your implementation must pass all provided unit tests.


Part 1: Core Algorithm

| Strategy | Candidates Kept | Quality | Cost |
|---|---|---|---|
| Greedy | 1 (best at each step) | Can miss globally optimal paths | O(T × V) |
| Beam Search | B (top-B at each step) | Balances quality and cost | O(T × B × V) |
| Exhaustive | All | Optimal | O(V^T) (intractable) |

Why Greedy Can Fail

Consider a scenario where:

  • Token A has probability 0.4 at step 1, but leads to a stop token with probability 0.9

  • Token B has probability 0.5 at step 1, but leads to a stop token with probability 0.3

Greedy picks B (0.5 > 0.4), yielding a total score of log(0.5) + log(0.3) ≈ -1.90. Beam search with width ≥ 2 also explores path A, finding a score of log(0.4) + log(0.9) ≈ -1.02, a significantly better sequence.
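The two path scores above can be checked directly with a few lines of Python:

```python
import math

# Path scores from the scenario above.
greedy_path = math.log(0.5) + math.log(0.3)  # token B, then stop
beam_path = math.log(0.4) + math.log(0.9)    # token A, then stop

print(f"greedy (B then stop): {greedy_path:.2f}")  # -1.90
print(f"beam   (A then stop): {beam_path:.2f}")    # -1.02
```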

Algorithm Steps

  1. Initialize: one beam containing input_seq with score 0.0

  2. For each generation step (up to max_token):

      • Expand: for each active beam, call next_token_fn(seq) and create a candidate for each token

      • Score: each candidate's score = parent's score + log(p_token)

      • Separate: candidates ending with stop_word_id go to the completed list

      • Prune: keep the top beam_size active candidates (by score)

  3. Return: completed sequences sorted by score; if none, use the active beams

Why Log Probability?

We accumulate log-probability instead of raw probability:

  • Avoids numerical underflow from multiplying many small values

  • log(p1 × p2 × ⋯ × pn) = log(p1) + log(p2) + ⋯ + log(pn), so products become sums

  • Higher (less negative) log-probability = better sequence
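The sum-of-logs identity is easy to sanity-check numerically:

```python
import math

probs = [0.5, 0.3, 0.2, 0.4]

product = 1.0
log_sum = 0.0
for p in probs:
    product *= p
    log_sum += math.log(p)

# log of the product equals the sum of the logs (up to float rounding)
assert abs(math.log(product) - log_sum) < 1e-12
```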


Part 2: Implementation

python
import math
from typing import List, Callable

def beam_search(
    input_seq: List[int],
    next_token_fn: Callable[[List[int]], List[float]],
    max_token: int,
    beam_size: int,
    stop_word_id: int
) -> List[List[int]]:
    """
    Beam search decoding for sequence generation.

    Args:
        input_seq: Initial token sequence (prompt)
        next_token_fn: Returns probability distribution over next tokens
        max_token: Maximum number of new tokens to generate
        beam_size: Number of beams to maintain
        stop_word_id: Token ID that signals end of generation

    Returns:
        Completed sequences sorted by cumulative log-probability (highest first)
    """
    # Each beam: (sequence, cumulative_log_prob)
    beams = [(input_seq[:], 0.0)]
    completed = []

    for _ in range(max_token):
        all_candidates = []

        for seq, score in beams:
            # Get probability distribution for next token
            probs = next_token_fn(seq)

            # Expand this beam with each possible next token
            for token_id, prob in enumerate(probs):
                if prob <= 0:
                    continue
                new_seq = seq + [token_id]
                new_score = score + math.log(prob)

                # Separate completed from active candidates
                if token_id == stop_word_id:
                    completed.append((new_seq, new_score))
                else:
                    all_candidates.append((new_seq, new_score))

        # If no active candidates remain, stop
        if not all_candidates:
            break

        # Keep only top beam_size active candidates
        all_candidates.sort(key=lambda x: x[1], reverse=True)
        beams = all_candidates[:beam_size]

    # If no sequence completed with stop word, return best active beams
    if not completed:
        completed = beams

    # Sort by score (highest first) and return sequences only
    completed.sort(key=lambda x: x[1], reverse=True)
    return [seq for seq, _ in completed]

Walkthrough

Using input_seq=[0], beam_size=2, max_token=2, stop_word_id=2:

python
def next_token_fn(seq):
    if len(seq) >= 2:
        return [0.0, 0.0, 1.0]  # force stop
    return [0.3, 0.5, 0.2]

Step 1 — expand [0]:

| Candidate | Score | Status |
|---|---|---|
| [0, 1] | log(0.5) = -0.69 | Active |
| [0, 0] | log(0.3) = -1.20 | Active |
| [0, 2] | log(0.2) = -1.61 | Completed |

Active beams (top 2): [0, 1], [0, 0]

Step 2 — expand [0, 1] and [0, 0] (both get [0.0, 0.0, 1.0]):

| Candidate | Score | Status |
|---|---|---|
| [0, 1, 2] | -0.69 + log(1.0) = -0.69 | Completed |
| [0, 0, 2] | -1.20 + log(1.0) = -1.20 | Completed |

No active candidates → stop.

Final result (sorted): [[0, 1, 2], [0, 0, 2], [0, 2]]
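For a self-contained check, the walkthrough can be reproduced end to end; the beam_search body below restates the Part 2 implementation so the snippet runs on its own:

```python
import math
from typing import Callable, List

def beam_search(
    input_seq: List[int],
    next_token_fn: Callable[[List[int]], List[float]],
    max_token: int,
    beam_size: int,
    stop_word_id: int,
) -> List[List[int]]:
    beams = [(input_seq[:], 0.0)]  # (sequence, cumulative log-prob)
    completed = []
    for _ in range(max_token):
        all_candidates = []
        for seq, score in beams:
            for token_id, prob in enumerate(next_token_fn(seq)):
                if prob <= 0:
                    continue
                cand = (seq + [token_id], score + math.log(prob))
                # Stop-word candidates are completed; the rest stay active
                (completed if token_id == stop_word_id else all_candidates).append(cand)
        if not all_candidates:
            break
        all_candidates.sort(key=lambda x: x[1], reverse=True)
        beams = all_candidates[:beam_size]
    if not completed:
        completed = beams
    completed.sort(key=lambda x: x[1], reverse=True)
    return [seq for seq, _ in completed]

def next_token_fn(seq):
    if len(seq) >= 2:
        return [0.0, 0.0, 1.0]  # force stop
    return [0.3, 0.5, 0.2]

result = beam_search([0], next_token_fn, max_token=2, beam_size=2, stop_word_id=2)
print(result)  # [[0, 1, 2], [0, 0, 2], [0, 2]]
```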


Part 3: Unit Tests

Your implementation must pass all of the following tests.

python
import math

def test_immediate_stop():
    """When stop word has highest probability, generate it immediately."""
    def next_token_fn(seq):
        return [0.1, 0.1, 0.8]

    result = beam_search([0], next_token_fn, max_token=5, beam_size=3, stop_word_id=2)
    assert result[0] == [0, 2], f"Expected [0, 2], got {result[0]}"

def test_greedy_search():
    """beam_size=1 should behave like greedy search."""
    def next_token_fn(seq):
        if len(seq) >= 3:
            return [0.1, 0.1, 0.8]  # generate stop after 2 tokens
        return [0.1, 0.7, 0.2]      # token 1 is best

    result = beam_search([0], next_token_fn, max_token=5, beam_size=1, stop_word_id=2)
    assert result[0] == [0, 1, 1, 2], f"Expected [0, 1, 1, 2], got {result[0]}"

def test_max_token_limit():
    """Generation stops at max_token even without stop word."""
    def next_token_fn(seq):
        return [0.8, 0.2, 0.0]  # stop word never generated

    result = beam_search([5], next_token_fn, max_token=3, beam_size=1, stop_word_id=2)
    assert len(result[0]) == 4, f"Expected length 4 (1 input + 3 generated), got {len(result[0])}"
    assert result[0] == [5, 0, 0, 0], f"Expected [5, 0, 0, 0], got {result[0]}"

def test_multiple_completed_sequences():
    """Should return multiple completed sequences when beam_size > 1."""
    def next_token_fn(seq):
        if len(seq) >= 2:
            return [0.0, 0.0, 1.0]  # force stop at step 2
        return [0.3, 0.5, 0.2]

    result = beam_search([0], next_token_fn, max_token=3, beam_size=2, stop_word_id=2)
    assert len(result) >= 2, f"Expected at least 2 sequences, got {len(result)}"
    assert result[0] == [0, 1, 2], f"Expected [0, 1, 2], got {result[0]}"
    assert result[1] == [0, 0, 2], f"Expected [0, 0, 2], got {result[1]}"

def test_beam_search_advantage():
    """
    Beam search finds a better completed sequence than greedy.

    Path A (token 0): 0.4 → STOP 0.9 → score = log(0.4)+log(0.9) ≈ -1.02
    Path B (token 1): 0.5 → STOP 0.3 → score = log(0.5)+log(0.3) ≈ -1.90

    Greedy picks B first, missing the globally better A→STOP path.
    """
    def next_token_fn(seq):
        last = seq[-1]
        if last == 5:
            return [0.4, 0.5, 0.1]
        elif last == 0:
            return [0.05, 0.05, 0.9]
        elif last == 1:
            return [0.05, 0.65, 0.3]
        return [0.0, 0.0, 1.0]

    beam = beam_search([5], next_token_fn, max_token=2, beam_size=2, stop_word_id=2)
    beam_completed = [s for s in beam if s[-1] == 2]

    assert [5, 0, 2] in beam_completed, f"Expected [5, 0, 2] in results, got {beam_completed}"
    # [5, 0, 2] should be ranked higher than [5, 1, 2]
    if [5, 1, 2] in beam_completed:
        assert beam_completed.index([5, 0, 2]) < beam_completed.index([5, 1, 2])

def test_sorted_by_score():
    """Results should be sorted by cumulative log-probability, highest first."""
    def next_token_fn(seq):
        if len(seq) == 1:
            return [0.3, 0.5, 0.2]
        return [0.0, 0.0, 1.0]

    result = beam_search([0], next_token_fn, max_token=2, beam_size=3, stop_word_id=2)
    # [0, 1, 2]: log(0.5) + log(1.0) = -0.69
    # [0, 0, 2]: log(0.3) + log(1.0) = -1.20
    # [0, 2]:    log(0.2) = -1.61
    assert result == [[0, 1, 2], [0, 0, 2], [0, 2]], f"Unexpected order: {result}"

# Run all tests
if __name__ == "__main__":
    test_immediate_stop()
    test_greedy_search()
    test_max_token_limit()
    test_multiple_completed_sequences()
    test_beam_search_advantage()
    test_sorted_by_score()
    print("All tests passed!")

Complexity Analysis

Time Complexity: O(T × B × V × log(B × V))

  • T = max_token (generation steps)

  • B = beam_size (beams expanded per step)

  • V = vocabulary size (candidates per beam)

  • log(B × V) comes from sorting the candidates

Space Complexity: O(B × T + C × T)

  • B × T for active beam sequences

  • C × T for completed sequences (where C = number of completed beams)

  • O(B × V) transient space for candidates at each step


Key Discussion Points in Interview

1. Log Probability vs. Raw Probability

Question: "Why not just multiply probabilities?"

Answer: Floating-point underflow. For a 400-token sequence with average token probability 0.1:

  • Raw: 0.1^400 = 10^-400 → underflows to 0.0 (double-precision floats bottom out near 10^-308)

  • Log: 400 × log(0.1) ≈ -921.0 → perfectly representable
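A quick demonstration (the 400-step count is an illustrative choice; doubles bottom out near 1e-308, with denormals down to about 5e-324):

```python
import math

# Multiplying 400 probabilities of 0.1 drives the raw product past the
# smallest representable double to 0.0; the running sum of logs stays an
# ordinary, perfectly representable float.
raw = 1.0
log_sum = 0.0
for _ in range(400):
    raw *= 0.1
    log_sum += math.log(0.1)

print(raw)      # 0.0 (underflow)
print(log_sum)  # about -921.03
```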

2. Length Normalization

Question: "Beam search tends to prefer shorter sequences. Why, and how do you fix it?"

Answer: Cumulative log-probability is always negative, so longer sequences accumulate more negative scores. Length-normalized scoring divides by sequence length:

score_norm(y) = (1 / |y|^α) × Σ_{t=1}^{|y|} log p(y_t | y_{<t})

where α ∈ [0, 1] controls normalization strength (α = 0: no normalization, α = 1: full normalization). Google's NMT paper uses α = 0.6.
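A minimal sketch of the normalized scorer; the helper name normalized_score is illustrative, not part of the given signature:

```python
import math

def normalized_score(log_probs, alpha=0.6):
    """Divide the cumulative log-prob by |y|^alpha (alpha=0.6 as in GNMT)."""
    return sum(log_probs) / (len(log_probs) ** alpha)

short = [math.log(0.5)] * 2  # raw score about -1.39
long = [math.log(0.5)] * 6   # raw score about -4.16

# Raw scoring strongly prefers the shorter sequence; with alpha=1.0 these
# per-token-identical sequences tie, and alpha=0.6 sits in between.
print(normalized_score(short), normalized_score(long))
print(normalized_score(short, alpha=1.0), normalized_score(long, alpha=1.0))
```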

3. Heap Optimization for Pruning

Question: "How would you optimize the candidate selection step?"

Answer: Instead of sorting all B × V candidates, use a min-heap of size B:

python
import heapq

# O(B*V * log(B)) instead of O(B*V * log(B*V))
top_beams = heapq.nlargest(beam_size, all_candidates, key=lambda x: x[1])

For large vocabularies (V ≫ B), this is significantly faster.
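A quick check that nlargest and a full sort agree on the kept beams, using a toy candidate list:

```python
import heapq

# Toy (sequence, score) candidates; the scores are illustrative.
candidates = [([0, 1], -0.69), ([0, 0], -1.20), ([0, 2], -1.61), ([0, 3], -2.30)]
beam_size = 2

# heapq.nlargest maintains a heap of size beam_size internally,
# so it avoids sorting the entire candidate list.
top_heap = heapq.nlargest(beam_size, candidates, key=lambda x: x[1])
top_sort = sorted(candidates, key=lambda x: x[1], reverse=True)[:beam_size]

assert top_heap == top_sort  # both keep [0, 1] and [0, 0]
```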

4. Diversity Across Beams

Question: "All beams tend to produce very similar sequences. How do you encourage diversity?"

Answer:

  • Group-based beam search: Divide beams into groups, add dissimilarity penalty between groups

  • Top-k / Top-p sampling before beam search: Restrict candidate tokens to reduce redundancy

  • Hamming diversity penalty: Penalize beams that share tokens at the same position
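As one concrete example, a minimal Hamming diversity penalty might look like this; the function name and strength parameter are illustrative, not from a specific library:

```python
def hamming_penalty(candidate, selected, strength=0.5):
    """Penalty grows with the number of positions where the candidate
    repeats a token already placed there by an accepted beam."""
    penalty = 0.0
    for other in selected:
        for pos, tok in enumerate(candidate):
            if pos < len(other) and other[pos] == tok:
                penalty += strength
    return penalty

selected = [[5, 1, 2]]
print(hamming_penalty([5, 1, 2], selected))  # 1.5: repeats all 3 positions
print(hamming_penalty([5, 0, 2], selected))  # 1.0: repeats positions 0 and 2
```

Subtracting this penalty from a candidate's log-probability score pushes later beams away from token choices the earlier beams already made.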

5. Beam Search in Production (LLM Inference)

Question: "How does beam search interact with KV cache in transformer inference?"

Answer:

  • Each beam maintains its own KV cache (key-value pairs from attention layers)

  • When a beam is pruned, its cache is discarded

  • When a beam is kept, its cache is reused for the next step

  • Memory usage: O(B × L × H), where L = sequence length and H = hidden dimension

  • This is why large beam sizes are expensive in transformer inference
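The prune/reuse bookkeeping above can be pictured as a gather over parent beam indices; the cache objects here are stand-in strings, whereas real inference engines gather per-layer key/value tensors:

```python
# Caches held by the current beams.
beam_caches = ["cache-A", "cache-B", "cache-C"]

# After re-ranking: beam 0 is pruned (its cache is discarded) and beam 2
# forks into two children, so each surviving beam records its parent index.
parent_indices = [1, 2, 2]

# Gather: each new beam reuses (or duplicates) its parent's cache.
new_caches = [beam_caches[p] for p in parent_indices]
print(new_caches)  # ['cache-B', 'cache-C', 'cache-C']
```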