Beam Search Decoding
Role: Machine Learning Engineer
Problem Overview
Implement beam search, a decoding algorithm widely used in sequence-to-sequence models (machine translation, text generation, speech recognition). Given a starting token sequence, a function that returns the next-token probability distribution, and search parameters, produce the top-scoring output sequences.
The interviewer provides the function signature. You must complete the implementation, and your code must pass unit tests.
Given Signature:

```python
from typing import List, Callable

def beam_search(
    input_seq: List[int],
    next_token_fn: Callable[[List[int]], List[float]],
    max_token: int,
    beam_size: int,
    stop_word_id: int
) -> List[List[int]]:
    pass
```

Parameters:
- `input_seq`: Initial token sequence (the prompt). Generated tokens are appended after this prefix.
- `next_token_fn`: Takes a token sequence and returns a probability distribution over the vocabulary; `next_token_fn(seq)[i]` is the probability of token `i` being the next token.
- `max_token`: Maximum number of new tokens to generate.
- `beam_size`: Number of candidate sequences (beams) to keep at each step.
- `stop_word_id`: Token ID that signals end of generation. A sequence is "completed" once it generates this token.

Returns:

- List of sequences (each including the `input_seq` prefix), sorted by cumulative log-probability from highest to lowest. Completed sequences (ending with `stop_word_id`) are preferred; if none exist, return the best active beams.
Example:

```python
# Vocabulary: {0: 'hello', 1: 'world', 2: '<EOS>'}
def simple_next_token(seq):
    return [0.2, 0.5, 0.3]

result = beam_search(
    input_seq=[0],
    next_token_fn=simple_next_token,
    max_token=2,
    beam_size=2,
    stop_word_id=2
)
# result: [[0, 2], [0, 1, 2], [0, 0, 2]]
# [0, 2]:    score = log(0.3) ≈ -1.20
# [0, 1, 2]: score = log(0.5) + log(0.3) ≈ -1.90
# [0, 0, 2]: score = log(0.2) + log(0.3) ≈ -2.81
```

Key Requirement: Your implementation must pass all provided unit tests.
Part 1: Core Algorithm
Beam Search vs. Greedy vs. Exhaustive Search
| Strategy | Candidates Kept | Quality | Cost |
|---|---|---|---|
| Greedy | 1 (best at each step) | Can miss globally optimal paths | $O(T \times V)$ |
| Beam Search | $B$ (top-$B$ at each step) | Balances quality and cost | $O(T \times B \times V)$ |
| Exhaustive | All | Optimal | $O(V^T)$ (intractable) |
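For contrast with the first row of the table, a minimal greedy decoder simply takes the argmax at every step. The sketch below is illustrative only (it reuses the `next_token_fn` interface from the problem statement but is not part of the required solution):

```python
from typing import Callable, List

def greedy_decode(
    input_seq: List[int],
    next_token_fn: Callable[[List[int]], List[float]],
    max_token: int,
    stop_word_id: int,
) -> List[int]:
    # Keep a single hypothesis and always extend it with the most probable
    # next token -- equivalent to beam_search with beam_size=1.
    seq = input_seq[:]
    for _ in range(max_token):
        probs = next_token_fn(seq)
        best_token = max(range(len(probs)), key=lambda i: probs[i])
        seq.append(best_token)
        if best_token == stop_word_id:
            break
    return seq
```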
Why Greedy Can Fail
Consider a scenario where:

- Token A has probability 0.4 at step 1, but leads to a stop token with probability 0.9
- Token B has probability 0.5 at step 1, but leads to a stop token with probability 0.3

Greedy picks B (0.5 > 0.4), yielding a total score of $\log(0.5) + \log(0.3) \approx -1.90$. Beam search with width $\geq 2$ also explores path A, finding a score of $\log(0.4) + \log(0.9) \approx -1.02$ — a significantly better sequence.
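A quick check of the two cumulative scores (illustrative only, using the hypothetical probabilities above):

```python
import math

# Path A: token A (p=0.4) followed by the stop token (p=0.9)
score_a = math.log(0.4) + math.log(0.9)   # ≈ -1.02
# Path B: token B (p=0.5) followed by the stop token (p=0.3)
score_b = math.log(0.5) + math.log(0.3)   # ≈ -1.90

assert score_a > score_b  # path A is the better completed sequence
```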
Algorithm Steps
1. Initialize: One beam containing `input_seq` with score 0.0
2. For each generation step (up to `max_token`):
   - Expand: For each active beam, get `next_token_fn(seq)` and create a candidate for each token
   - Score: Each candidate's score = parent's score + $\log(p_{\text{token}})$
   - Separate: Candidates ending with `stop_word_id` go to the `completed` list
   - Prune: Keep the top `beam_size` active candidates (by score)
3. Return: Completed sequences sorted by score; if none, use the active beams
Why Log Probability?
We accumulate log-probability instead of raw probability:

- Avoids numerical underflow from multiplying many small values
- $\log(p_1 \times p_2 \times \cdots \times p_n) = \log(p_1) + \log(p_2) + \cdots + \log(p_n)$
- Higher (less negative) log-probability = better sequence
Part 2: Implementation
```python
import math
from typing import List, Callable

def beam_search(
    input_seq: List[int],
    next_token_fn: Callable[[List[int]], List[float]],
    max_token: int,
    beam_size: int,
    stop_word_id: int
) -> List[List[int]]:
    """
    Beam search decoding for sequence generation.

    Args:
        input_seq: Initial token sequence (prompt)
        next_token_fn: Returns probability distribution over next tokens
        max_token: Maximum number of new tokens to generate
        beam_size: Number of beams to maintain
        stop_word_id: Token ID that signals end of generation

    Returns:
        Completed sequences sorted by cumulative log-probability (highest first)
    """
    # Each beam: (sequence, cumulative_log_prob)
    beams = [(input_seq[:], 0.0)]
    completed = []

    for _ in range(max_token):
        all_candidates = []
        for seq, score in beams:
            # Get probability distribution for next token
            probs = next_token_fn(seq)
            # Expand this beam with each possible next token
            for token_id, prob in enumerate(probs):
                if prob <= 0:
                    continue
                new_seq = seq + [token_id]
                new_score = score + math.log(prob)
                # Separate completed from active candidates
                if token_id == stop_word_id:
                    completed.append((new_seq, new_score))
                else:
                    all_candidates.append((new_seq, new_score))

        # If no active candidates remain, stop
        if not all_candidates:
            break

        # Keep only the top beam_size active candidates
        all_candidates.sort(key=lambda x: x[1], reverse=True)
        beams = all_candidates[:beam_size]

    # If no sequence completed with the stop word, return the best active beams
    if not completed:
        completed = beams

    # Sort by score (highest first) and return sequences only
    completed.sort(key=lambda x: x[1], reverse=True)
    return [seq for seq, _ in completed]
```

Walkthrough
Using `input_seq=[0]`, `beam_size=2`, `max_token=2`, `stop_word_id=2`:

```python
def next_token_fn(seq):
    if len(seq) >= 2:
        return [0.0, 0.0, 1.0]  # force stop
    return [0.3, 0.5, 0.2]
```

Step 1 — expand [0]:
| Candidate | Score | Status |
|---|---|---|
| [0, 1] | $\log(0.5) = -0.69$ | Active |
| [0, 0] | $\log(0.3) = -1.20$ | Active |
| [0, 2] | $\log(0.2) = -1.61$ | Completed |
Active beams (top 2): [0, 1], [0, 0]
Step 2 — expand [0, 1] and [0, 0] (both get [0.0, 0.0, 1.0]):
| Candidate | Score | Status |
|---|---|---|
| [0, 1, 2] | $-0.69 + \log(1.0) = -0.69$ | Completed |
| [0, 0, 2] | $-1.20 + \log(1.0) = -1.20$ | Completed |
No active candidates → stop.
Final result (sorted): [[0, 1, 2], [0, 0, 2], [0, 2]]
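The walkthrough can be reproduced end to end, assuming the `beam_search` implementation from Part 2 is in scope:

```python
def next_token_fn(seq):
    if len(seq) >= 2:
        return [0.0, 0.0, 1.0]  # force stop
    return [0.3, 0.5, 0.2]

result = beam_search([0], next_token_fn, max_token=2, beam_size=2, stop_word_id=2)
print(result)  # [[0, 1, 2], [0, 0, 2], [0, 2]]
```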
Part 3: Unit Tests
Your implementation must pass all of the following tests.
```python
import math

def test_immediate_stop():
    """When stop word has highest probability, generate it immediately."""
    def next_token_fn(seq):
        return [0.1, 0.1, 0.8]
    result = beam_search([0], next_token_fn, max_token=5, beam_size=3, stop_word_id=2)
    assert result[0] == [0, 2], f"Expected [0, 2], got {result[0]}"

def test_greedy_search():
    """beam_size=1 should behave like greedy search."""
    def next_token_fn(seq):
        if len(seq) >= 3:
            return [0.1, 0.1, 0.8]  # generate stop after 2 tokens
        return [0.1, 0.7, 0.2]  # token 1 is best
    result = beam_search([0], next_token_fn, max_token=5, beam_size=1, stop_word_id=2)
    assert result[0] == [0, 1, 1, 2], f"Expected [0, 1, 1, 2], got {result[0]}"

def test_max_token_limit():
    """Generation stops at max_token even without stop word."""
    def next_token_fn(seq):
        return [0.8, 0.2, 0.0]  # stop word never generated
    result = beam_search([5], next_token_fn, max_token=3, beam_size=1, stop_word_id=2)
    assert len(result[0]) == 4, f"Expected length 4 (1 input + 3 generated), got {len(result[0])}"
    assert result[0] == [5, 0, 0, 0], f"Expected [5, 0, 0, 0], got {result[0]}"

def test_multiple_completed_sequences():
    """Should return multiple completed sequences when beam_size > 1."""
    def next_token_fn(seq):
        if len(seq) >= 2:
            return [0.0, 0.0, 1.0]  # force stop at step 2
        return [0.3, 0.5, 0.2]
    result = beam_search([0], next_token_fn, max_token=3, beam_size=2, stop_word_id=2)
    assert len(result) >= 2, f"Expected at least 2 sequences, got {len(result)}"
    assert result[0] == [0, 1, 2], f"Expected [0, 1, 2], got {result[0]}"
    assert result[1] == [0, 0, 2], f"Expected [0, 0, 2], got {result[1]}"

def test_beam_search_advantage():
    """
    Beam search finds a better completed sequence than greedy.
    Path A (token 0): 0.4 → STOP 0.9 → score = log(0.4) + log(0.9) ≈ -1.02
    Path B (token 1): 0.5 → STOP 0.3 → score = log(0.5) + log(0.3) ≈ -1.90
    Greedy picks B first, missing the globally better A→STOP path.
    """
    def next_token_fn(seq):
        last = seq[-1]
        if last == 5:
            return [0.4, 0.5, 0.1]
        elif last == 0:
            return [0.05, 0.05, 0.9]
        elif last == 1:
            return [0.05, 0.65, 0.3]
        return [0.0, 0.0, 1.0]
    beam = beam_search([5], next_token_fn, max_token=2, beam_size=2, stop_word_id=2)
    beam_completed = [s for s in beam if s[-1] == 2]
    assert [5, 0, 2] in beam_completed, f"Expected [5, 0, 2] in results, got {beam_completed}"
    # [5, 0, 2] should be ranked higher than [5, 1, 2]
    if [5, 1, 2] in beam_completed:
        assert beam_completed.index([5, 0, 2]) < beam_completed.index([5, 1, 2])

def test_sorted_by_score():
    """Results should be sorted by cumulative log-probability, highest first."""
    def next_token_fn(seq):
        if len(seq) == 1:
            return [0.3, 0.5, 0.2]
        return [0.0, 0.0, 1.0]
    result = beam_search([0], next_token_fn, max_token=2, beam_size=3, stop_word_id=2)
    # [0, 1, 2]: log(0.5) + log(1.0) = -0.69
    # [0, 0, 2]: log(0.3) + log(1.0) = -1.20
    # [0, 2]:    log(0.2) = -1.61
    assert result == [[0, 1, 2], [0, 0, 2], [0, 2]], f"Unexpected order: {result}"

# Run all tests
if __name__ == "__main__":
    test_immediate_stop()
    test_greedy_search()
    test_max_token_limit()
    test_multiple_completed_sequences()
    test_beam_search_advantage()
    test_sorted_by_score()
    print("All tests passed!")
```

Complexity Analysis
Time Complexity: $O(T \times B \times V \times \log(B \times V))$

- $T$ = `max_token` (generation steps)
- $B$ = `beam_size` (beams expanded per step)
- $V$ = vocabulary size (candidates per beam)
- $\log(B \times V)$ from sorting candidates

Space Complexity: $O(B \times T + C \times T)$

- $B \times T$ for active beam sequences
- $C \times T$ for completed sequences (where $C$ = number of completed beams)
- $O(B \times V)$ transient space for candidates per step
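As a rough worked example with hypothetical but realistic numbers, say $T = 50$, $B = 4$, and $V = 32{,}000$: each step scores $B \times V = 128{,}000$ candidates, and sorting them costs on the order of $128{,}000 \times \log_2(128{,}000) \approx 2.2$ million comparisons, repeated $T$ times.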
Key Discussion Points in Interview
1. Log Probability vs. Raw Probability
Question: "Why not just multiply probabilities?"
Answer: Floating-point underflow. For a 100-token sequence with average token probability 0.1:

- Raw: $0.1^{100} = 10^{-100}$ → already 0.0 in single precision (float32); a few hundred more tokens and it underflows double precision as well
- Log: $100 \times \log(0.1) \approx -230.3$ → perfectly representable
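A quick demonstration in plain Python (float64), using a somewhat longer sequence so that even double precision hits zero:

```python
import math

# A 400-token sequence with average token probability 0.1:
raw = 0.1 ** 400                   # 10^-400 underflows: prints 0.0
log_score = 400 * math.log(0.1)    # ≈ -921.0, perfectly representable

print(raw, log_score)
```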
2. Length Normalization
Question: "Beam search tends to prefer shorter sequences. Why, and how do you fix it?"
Answer: Cumulative log-probability is always negative, so longer sequences accumulate more negative scores. Length-normalized scoring divides by sequence length:
$$\text{score}_{\text{norm}} = \frac{1}{|y|^\alpha} \sum_{t=1}^{|y|} \log p(y_t \mid y_{<t})$$

where $\alpha \in [0, 1]$ controls normalization strength ($\alpha = 0$: no normalization, $\alpha = 1$: full normalization). Google's NMT paper uses $\alpha = 0.6$.
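A minimal sketch of applying this penalty when ranking the final candidates (the helper name and default $\alpha$ are illustrative assumptions, not part of the required signature):

```python
import math
from typing import List, Tuple

def rank_with_length_penalty(
    completed: List[Tuple[List[int], float]],
    prompt_len: int,
    alpha: float = 0.6,
) -> List[List[int]]:
    # Divide each cumulative log-probability by (number of generated tokens)^alpha,
    # then sort best-first; this counteracts the bias toward short sequences.
    def normalized(item):
        seq, score = item
        gen_len = max(len(seq) - prompt_len, 1)
        return score / (gen_len ** alpha)
    return [seq for seq, _ in sorted(completed, key=normalized, reverse=True)]
```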
3. Heap Optimization for Pruning
Question: "How would you optimize the candidate selection step?"
Answer: Instead of sorting all $B \times V$ candidates, use a min-heap of size $B$:

```python
import heapq

# O(B*V * log(B)) instead of O(B*V * log(B*V))
top_beams = heapq.nlargest(beam_size, all_candidates, key=lambda x: x[1])
```

For large vocabularies ($V \gg B$), this is significantly faster.
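An alternative sketch that avoids materializing the full candidate list, maintaining a bounded min-heap as candidates stream in (names are illustrative, and for simplicity this ignores the completed/active split from Part 2):

```python
import heapq
import math

def top_k_candidates(beams, next_token_fn, beam_size):
    # Min-heap of at most beam_size entries; the worst kept candidate sits at
    # heap[0], so a better one simply replaces it via heappushpop.
    heap = []
    for seq, score in beams:
        for token_id, prob in enumerate(next_token_fn(seq)):
            if prob <= 0:
                continue
            cand = (score + math.log(prob), seq + [token_id])
            if len(heap) < beam_size:
                heapq.heappush(heap, cand)
            else:
                heapq.heappushpop(heap, cand)
    # Return (sequence, score) pairs, best-first
    return [(seq, s) for s, seq in sorted(heap, reverse=True)]
```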
4. Diverse Beam Search
Question: "All beams tend to produce very similar sequences. How do you encourage diversity?"
Answer:
- Group-based beam search: Divide beams into groups and add a dissimilarity penalty between groups
- Top-k / Top-p sampling before beam search: Restrict candidate tokens to reduce redundancy
- Hamming diversity penalty: Penalize beams that share tokens at the same position
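As an illustrative sketch of the Hamming-style penalty (not part of the problem statement), one can discount the log-probability of any token that earlier beam groups already emitted at the same position before expansion:

```python
from typing import Dict, List

def apply_hamming_penalty(
    log_probs: List[float],
    tokens_chosen_at_position: Dict[int, int],
    penalty: float = 0.5,
) -> List[float]:
    # Subtract a fixed penalty per prior use of a token at this position,
    # discouraging different beam groups from repeating each other.
    adjusted = list(log_probs)
    for token_id, count in tokens_chosen_at_position.items():
        adjusted[token_id] -= penalty * count
    return adjusted
```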
5. Beam Search in Production (LLM Inference)
Question: "How does beam search interact with KV cache in transformer inference?"
Answer:
- Each beam maintains its own KV cache (key-value pairs from attention layers)
- When a beam is pruned, its cache is discarded
- When a beam is kept, its cache is reused for the next step
- Memory usage: $O(B \times L \times H)$ where $L$ = sequence length, $H$ = hidden dimension
- This is why large beam sizes are expensive in transformer inference
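A minimal sketch of the bookkeeping involved, using a NumPy array as a stand-in for one layer's key/value cache (shapes and names are illustrative assumptions):

```python
import numpy as np

# Hypothetical cache for one attention layer: (beams, seq_len, hidden_dim)
beam_size, seq_len, hidden = 4, 10, 64
kv_cache = np.random.randn(beam_size, seq_len, hidden)

# After pruning, each surviving beam records which parent beam it came from.
# Reusing the cache is then a gather along the beam dimension; entries for
# pruned parents are simply dropped.
parent_indices = np.array([1, 1, 3, 0])   # e.g. two survivors share parent 1
kv_cache = kv_cache[parent_indices]        # shape stays (beam_size, seq_len, hidden)
```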