Part Seven · Chapter 22 · 3 of 3

Training &
Inference

How random weights become a language model — cross-entropy loss, backpropagation, the training loop, sampling strategies, and how a next-token predictor becomes a conversational assistant.

§01

What training means

At the end of Part 1, the forward pass produced a probability distribution over the vocabulary. The weights that produced it were random — the predictions were meaningless. Training is the process of adjusting those weights so the predictions become good.

"Good" is defined precisely: assign high probability to the token that actually comes next in real text. The training corpus is billions of sentences. For each one, the model predicts the next token at every position; we measure how wrong it was; we compute how each weight contributed to that wrongness; we nudge every weight slightly in the direction that reduces it. Repeat hundreds of billions of times.

Three questions to answer before diving in:

[Figure: Training loop overview]
§02

Tokenized datasets

Raw text is tokenized once and stored as a flat array of integers. Training draws chunks from this array — contiguous runs of context_length + 1 tokens. The first context_length tokens form the input; the last context_length tokens form the targets, shifted by one position.

This shift-by-one trick means every chunk of text yields context_length training examples simultaneously. A context window of 2,048 tokens gives 2,048 next-token prediction problems from a single forward pass.

dataset.js — chunking tokens into training examples JS
// How raw token IDs become training pairs (input, target).
// The target is just the input shifted one position to the right.

function makeDataset(tokenIds, contextLen) {
  const examples = [];
  for (let i = 0; i + contextLen < tokenIds.length; i++) {
    const input  = tokenIds.slice(i, i + contextLen);
    const target = tokenIds.slice(i + 1, i + contextLen + 1);
    examples.push({ input, target });
  }
  return examples;
}

// Mini corpus: "the cat sat on the mat"
// Vocabulary: {'the':0, 'cat':1, 'sat':2, 'on':3, 'mat':4}
const VOCAB = ['the', 'cat', 'sat', 'on', 'mat'];
const corpus = [0, 1, 2, 3, 0, 4];  // token IDs

const contextLen = 3;
const dataset = makeDataset(corpus, contextLen);

console.log(`Corpus: [${corpus.map(id => '"'+VOCAB[id]+'"').join(', ')}]`);
console.log(`Context length: ${contextLen}`);
console.log(`Number of training examples: ${dataset.length}\n`);

dataset.forEach(({input, target}, i) => {
  const iStr = input .map(id => '"'+VOCAB[id]+'"').join(', ');
  const tStr = target.map(id => '"'+VOCAB[id]+'"').join(', ');
  console.log(`Example ${i+1}:`);
  console.log(`  input:  [${iStr}]`);
  console.log(`  target: [${tStr}]`);
  console.log(`  → at position 0, predict "${VOCAB[target[0]]}"; at position 1, predict "${VOCAB[target[1]]}"; ...`);
});

In practice, data is batched: many examples are processed in parallel. A batch of 512 examples each of length 2,048 means ~1 million token predictions per forward pass. Repeating passes of that size over a corpus of billions of sentences is why training requires thousands of GPUs running for weeks.
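As a rough illustration of batching, the sketch below draws random chunks from the flat token array and stacks them. The name `sampleBatch`, the random-offset strategy, and the injectable `rng` are assumptions made for this example; real data loaders differ in how they shuffle and shard.

```javascript
// Draw `batchSize` random chunks of contextLen + 1 tokens and split each
// into (input, target). `rng` is injectable so a run can be reproduced.
function sampleBatch(tokenIds, contextLen, batchSize, rng = Math.random) {
  const maxStart = tokenIds.length - contextLen - 1; // last valid chunk start
  if (maxStart < 0) throw new Error('corpus shorter than contextLen + 1');
  const inputs = [], targets = [];
  for (let b = 0; b < batchSize; b++) {
    const start = Math.floor(rng() * (maxStart + 1));
    inputs.push(tokenIds.slice(start, start + contextLen));
    targets.push(tokenIds.slice(start + 1, start + contextLen + 1));
  }
  return { inputs, targets };
}

// Stand-in corpus of 100 token IDs (made-up data, not real text).
const tokens = Array.from({ length: 100 }, (_, i) => i % 7);
const batch = sampleBatch(tokens, 8, 4);
const predictionsPerPass = batch.inputs.length * batch.inputs[0].length;
console.log(`batch of ${batch.inputs.length} x ${batch.inputs[0].length} = ${predictionsPerPass} predictions`);
console.log(`at 512 x 2048: ${512 * 2048} predictions per forward pass`);
```

Every row of `targets` is the matching row of `inputs` shifted by one position, so the batch carries batchSize × contextLen prediction problems.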

§03

Cross-entropy loss

We need a number that measures how bad a prediction is. The right number is cross-entropy loss: given a probability distribution over the vocabulary and the true next token, the loss is the negative log probability assigned to the correct token:

L = −log P(correct token)

Why negative log? Because probabilities are between 0 and 1, so their logs are negative — we negate to get a positive loss. A perfect prediction (P = 1.0) gives loss 0. Assigning probability 0.01 to the correct token gives loss 4.6. Predicting completely wrong (P → 0) gives loss → ∞.

The loss across a batch is the mean over all token positions. This single scalar is what the entire training machinery works to minimize.

loss.js — cross-entropy loss from scratch JS
// Cross-entropy loss — the single number training minimizes.

function softmax(logits) {
  const max = Math.max(...logits);
  const ex  = logits.map(x => Math.exp(x - max));
  const sum = ex.reduce((a,b) => a+b, 0);
  return ex.map(x => x/sum);
}

function crossEntropyLoss(logits, targetId) {
  // logits: raw scores from the model [vocab_size]
  // targetId: index of the correct token
  const probs = softmax(logits);
  const correctProb = probs[targetId];
  // Clamp to avoid log(0)
  return -Math.log(Math.max(correctProb, 1e-10));
}

function batchLoss(allLogits, targetIds) {
  // allLogits: [seq_len × vocab_size], targetIds: [seq_len]
  const losses = allLogits.map((logits, i) =>
    crossEntropyLoss(logits, targetIds[i])
  );
  return losses.reduce((a,b) => a+b, 0) / losses.length;
}

// ── Example: vocabulary of 5 tokens ─────────────────────────────
const VOCAB = ['the', 'cat', 'sat', 'on', 'mat'];
const correctId = 2; // "sat"

// Case 1: uniform logits (untrained model)
const uniformLogits = [0, 0, 0, 0, 0];
const l1 = crossEntropyLoss(uniformLogits, correctId);
console.log('=== Untrained model (uniform logits) ===');
console.log(`Logits:  [${uniformLogits.join(', ')}]`);
console.log(`Probs:   [${softmax(uniformLogits).map(p=>p.toFixed(3)).join(', ')}]`);
console.log(`P(sat):  ${softmax(uniformLogits)[correctId].toFixed(3)}`);
console.log(`Loss:    ${l1.toFixed(4)}  (= log(5) — random-guess baseline)\n`);

// Case 2: model is starting to learn
const learningLogits = [-1, 0.5, 2.0, -0.5, 0.3];
const l2 = crossEntropyLoss(learningLogits, correctId);
console.log('=== Partially trained model ===');
console.log(`Logits:  [${learningLogits.join(', ')}]`);
console.log(`Probs:   [${softmax(learningLogits).map(p=>p.toFixed(3)).join(', ')}]`);
console.log(`P(sat):  ${softmax(learningLogits)[correctId].toFixed(3)}`);
console.log(`Loss:    ${l2.toFixed(4)}\n`);

// Case 3: well-trained model
const goodLogits = [-5, -3, 8, -4, -2];
const l3 = crossEntropyLoss(goodLogits, correctId);
console.log('=== Well-trained model ===');
console.log(`Logits:  [${goodLogits.join(', ')}]`);
console.log(`Probs:   [${softmax(goodLogits).map(p=>p.toFixed(3)).join(', ')}]`);
console.log(`P(sat):  ${softmax(goodLogits)[correctId].toFixed(3)}`);
console.log(`Loss:    ${l3.toFixed(4)}`);

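The perplexity of a model is exp(average_loss). It measures how "surprised" the model is by the text, and can be read as the number of equally likely tokens the model is effectively choosing between. A model that guesses uniformly over a 50,000-token vocabulary, as an untrained model roughly does, has perplexity 50,000. GPT-2 achieves roughly 35 on some standard benchmarks, with the exact figure depending on the dataset and model size; GPT-4 is substantially lower.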
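The relationship between loss and perplexity can be checked in a few lines. This is a minimal sketch: the function name `perplexity` and the example probabilities are made up for illustration.

```javascript
// Perplexity = exp(mean negative log-probability of the correct tokens).
function perplexity(correctTokenProbs) {
  const meanLoss = correctTokenProbs
    .reduce((sum, p) => sum - Math.log(p), 0) / correctTokenProbs.length;
  return Math.exp(meanLoss);
}

// A uniform guesser over V tokens assigns 1/V to every correct token,
// so its perplexity is V no matter what the text is.
const vocabSize = 50000;
const uniformPpl = perplexity(new Array(10).fill(1 / vocabSize));
console.log(`uniform over ${vocabSize}: perplexity = ${uniformPpl.toFixed(0)}`);

// Mixed confidence: perplexity is the geometric mean of 1/p.
const mixedPpl = perplexity([0.5, 0.25, 0.125]);
console.log(`mixed example: perplexity = ${mixedPpl.toFixed(3)}`);
```

In the second example the three probabilities are 1/2, 1/4 and 1/8, whose geometric mean is 1/4, so the model behaves as if it were choosing among four equally likely tokens.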

§04

Backpropagation

We have a loss. Now we need to know: for each weight in the model, does increasing it make the loss go up or down, and by how much? This is the gradient — a vector with one number per weight, pointing in the direction of steepest loss increase.

Backpropagation computes this gradient efficiently using the chain rule. The computation graph of the forward pass — every addition, multiplication, softmax — is traversed in reverse. At each node, we compute how much that operation contributed to the final loss, and pass the signal back to its inputs.

A concrete walk-through

Consider a tiny network: one input x, two weights and two biases, output y = relu(w2 · relu(w1·x + b1) + b2), and a loss L = (y − target)². We want dL/dw1 and dL/dw2, and we check both against a numerical gradient:

backprop.js — autodiff from scratch JS
// Manual backpropagation through a tiny network.
// We implement forward pass, loss, and gradient computation by hand
// to make the chain rule fully explicit.

// Network: y = relu(w2 · relu(w1·x + b1) + b2)
// Loss: L = (y - target)^2

function relu(x) { return Math.max(0, x); }
function reluGrad(x) { return x > 0 ? 1 : 0; }

function forward(x, w1, b1, w2, b2) {
  const z1 = w1 * x + b1;           // pre-activation layer 1
  const a1 = relu(z1);               // activation
  const z2 = w2 * a1 + b2;          // pre-activation layer 2
  const y  = relu(z2);               // output
  return { z1, a1, z2, y };
}

function loss(y, target) {
  return (y - target) ** 2;
}

function backward(x, w1, b1, w2, b2, target, cache) {
  const { z1, a1, z2, y } = cache;

  // dL/dy — gradient of loss w.r.t. output
  const dL_dy = 2 * (y - target);

  // dL/dz2 — chain through relu at z2
  const dL_dz2 = dL_dy * reluGrad(z2);

  // dL/dw2 = dL/dz2 · dz2/dw2 = dL/dz2 · a1
  const dL_dw2 = dL_dz2 * a1;
  // dL/db2 = dL/dz2 · 1
  const dL_db2 = dL_dz2;

  // dL/da1 = dL/dz2 · dz2/da1 = dL/dz2 · w2
  const dL_da1 = dL_dz2 * w2;

  // dL/dz1 — chain through relu at z1
  const dL_dz1 = dL_da1 * reluGrad(z1);

  // dL/dw1 = dL/dz1 · x
  const dL_dw1 = dL_dz1 * x;
  // dL/db1 = dL/dz1 · 1
  const dL_db1 = dL_dz1;

  return { dL_dw1, dL_db1, dL_dw2, dL_db2 };
}

// ── Verify by finite difference (numerical gradient check) ───────
function numericalGrad(fn, param, eps=1e-5) {
  return (fn(param + eps) - fn(param - eps)) / (2 * eps);
}

let w1=0.5, b1=0.1, w2=0.8, b2=-0.2;
const x=2.0, target=3.0;

const cache = forward(x, w1, b1, w2, b2);
const L = loss(cache.y, target);
const grads = backward(x, w1, b1, w2, b2, target, cache);

console.log('Forward pass:');
console.log(`  x=${x}, target=${target}`);
console.log(`  z1=${cache.z1.toFixed(4)}, a1=${cache.a1.toFixed(4)}`);
console.log(`  z2=${cache.z2.toFixed(4)}, y=${cache.y.toFixed(4)}`);
console.log(`  Loss L=${L.toFixed(6)}\n`);

console.log('Analytical gradients (backprop):');
console.log(`  dL/dw1 = ${grads.dL_dw1.toFixed(6)}`);
console.log(`  dL/dw2 = ${grads.dL_dw2.toFixed(6)}\n`);

// Verify with numerical gradients
const ng_w1 = numericalGrad(w => {
  const c = forward(x, w, b1, w2, b2);
  return loss(c.y, target);
}, w1);
const ng_w2 = numericalGrad(w => {
  const c = forward(x, w1, b1, w, b2);
  return loss(c.y, target);
}, w2);

console.log('Numerical gradients (finite difference check):');
console.log(`  dL/dw1 ≈ ${ng_w1.toFixed(6)}  ← should match above`);
console.log(`  dL/dw2 ≈ ${ng_w2.toFixed(6)}  ← should match above`);
console.log(`\nMax error: ${Math.max(Math.abs(grads.dL_dw1 - ng_w1), Math.abs(grads.dL_dw2 - ng_w2)).toExponential(2)}`);

In a real Transformer, the same chain rule applies, but through billions of parameters and dozens of stacked layers. PyTorch and JAX automate this by recording the computation graph during the forward pass, then traversing it in reverse. The math is identical to what you just ran; the scale is not.

Why residual connections matter for backprop: in a deep network without residuals, gradients are multiplied by the derivative at every layer on the way back. If those derivatives are slightly less than 1, the gradient shrinks exponentially — the vanishing gradient problem. Residual connections add a direct path (gradient = 1) that bypasses the multiplication chain.
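The effect can be seen in a scalar toy network. The sketch below is an illustration under stated assumptions (one number per layer, a shared positive weight `w`, tanh as the nonlinearity), not a model of a real Transformer block: it multiplies the local derivatives layer by layer, once for plain layers h → tanh(w·h) and once for residual layers h → h + tanh(w·h).

```javascript
// Track d(output)/d(input) through `depth` layers via the chain rule.
function inputGradient(depth, w, h0, useResidual) {
  let h = h0, grad = 1;
  for (let layer = 0; layer < depth; layer++) {
    const t = Math.tanh(w * h);
    const branchDeriv = w * (1 - t * t);   // d tanh(w*h) / dh
    // Residual layer: the identity path contributes the extra "1 +".
    grad *= useResidual ? 1 + branchDeriv : branchDeriv;
    h = useResidual ? h + t : t;
  }
  return grad;
}

const depth = 20, w = 0.5, h0 = 1.0;       // assumed toy values
const plainGrad = inputGradient(depth, w, h0, false);
const residualGrad = inputGradient(depth, w, h0, true);
console.log(`plain:    ${plainGrad.toExponential(2)}`);
console.log(`residual: ${residualGrad.toExponential(2)}`);
```

With these values every plain factor is at most 0.5, so the product shrinks below 0.5²⁰; every residual factor is at least 1, so the gradient cannot vanish. In a real network the branch derivatives vary in sign and size, so the guarantee is weaker, but the direct path is still there.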
§05

Gradient descent & Adam

Once we have the gradient, we update each weight by subtracting a small multiple of it. The multiplier is the learning rate — perhaps the most important hyperparameter in training:

w ← w − η · (dL/dw)

Vanilla gradient descent has a problem: the gradient is noisy (computed on a small batch) and the same learning rate applies to every parameter regardless of its history. The Adam optimizer fixes both problems by tracking a running average of the gradient and of the squared gradient for each parameter, using them to produce an adaptive, low-variance update.

adam.js — Adam optimizer from scratch JS
// Adam optimizer — the standard for training LLMs.
// Maintains per-parameter moving averages of gradient (m) and
// gradient-squared (v), then computes bias-corrected updates.

function adamOptimizer(lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8) {
  const state = {};  // { paramName: { m, v, t } }

  return function update(params, grads) {
    for (const [name, grad] of Object.entries(grads)) {
      if (!state[name]) state[name] = { m: 0, v: 0, t: 0 };
      const s = state[name];
      s.t += 1;

      // Update biased first moment estimate
      s.m = beta1 * s.m + (1 - beta1) * grad;

      // Update biased second moment estimate
      s.v = beta2 * s.v + (1 - beta2) * grad * grad;

      // Bias correction (important in early steps when m,v are near zero)
      const m_hat = s.m / (1 - Math.pow(beta1, s.t));
      const v_hat = s.v / (1 - Math.pow(beta2, s.t));

      // Update parameter
      params[name] -= lr * m_hat / (Math.sqrt(v_hat) + eps);
    }
    return params;
  };
}

// ── Demo: train a single weight to match a target ─────────────────
function runComparison(optimizerFn, steps, name) {
  const target = 0.0;
  const params = { w: 3.0 };   // starting weight (far from optimal 0.0)
  const history = [];          // { w, loss } recorded before each update

  const update = optimizerFn();

  for (let i = 0; i < steps; i++) {
    // MSE loss: L = (w - target)^2, dL/dw = 2*(w - target)
    const loss = (params.w - target) ** 2;
    const grad = 2 * (params.w - target);
    history.push({ w: params.w, loss });
    update(params, { w: grad });
  }

  // Print w and loss as they were at each logged step (not the final w).
  console.log(`\n${name}:`);
  [0, 4, 9, 19, 49].filter(i => i < steps).forEach(i => {
    const { w, loss } = history[i];
    console.log(`  step ${String(i+1).padStart(3)}: w=${w.toFixed(4).padStart(8)}, loss=${loss.toFixed(6)}`);
  });
}

// SGD (no momentum)
function sgd(lr=0.1) {
  return () => (params, grads) => {
    for (const [k, g] of Object.entries(grads)) params[k] -= lr * g;
    return params;
  };
}

runComparison(sgd(0.1),    50, 'SGD (lr=0.1)');
runComparison(() => adamOptimizer(0.1), 50, 'Adam (lr=0.1)');

Real LLM training uses a learning rate schedule: the rate is warmed up linearly from zero over the first few thousand steps (to avoid large updates with unreliable initial gradients), then decayed following a cosine curve through the rest of training. The peak learning rate for the 175-billion-parameter GPT-3 was 6×10⁻⁵ (the smallest GPT-3 variants used 6×10⁻⁴): a tiny number applied to 175 billion parameters at every step.
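A warmup-plus-cosine schedule fits in one small function. This is a sketch; the name `learningRateAt`, the `minLr` floor, and all the numbers in `schedule` are hypothetical values chosen for the example, not the settings of any particular model.

```javascript
// Linear warmup from 0 to peakLr, then cosine decay down to minLr.
function learningRateAt(step, { peakLr, warmupSteps, totalSteps, minLr = 0 }) {
  if (step < warmupSteps) return peakLr * (step / warmupSteps);
  const progress = Math.min(1, (step - warmupSteps) / (totalSteps - warmupSteps));
  const cosine = 0.5 * (1 + Math.cos(Math.PI * progress)); // goes 1 -> 0
  return minLr + (peakLr - minLr) * cosine;
}

// Hypothetical settings for illustration only.
const schedule = { peakLr: 3e-4, warmupSteps: 2000, totalSteps: 100000, minLr: 3e-5 };
for (const step of [0, 1000, 2000, 51000, 100000]) {
  const lr = learningRateAt(step, schedule);
  console.log(`step ${String(step).padStart(6)}: lr = ${lr.toExponential(2)}`);
}
```

The optimizer then uses `learningRateAt(step, schedule)` in place of a fixed `lr` at every update.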

§06

The training loop

We now have all the pieces. The training loop is the engine that runs forward pass → loss → backward pass → optimizer update, repeatedly, over the entire dataset, possibly for multiple passes (epochs). For LLMs trained on internet-scale text, there is typically only about one epoch: the dataset is so large that one pass is enough, and repeating documents raises the risk of overfitting.

Production training loops add two more steps between backprop and the optimizer update. First, gradient clipping: if the global norm of all gradients exceeds a threshold (typically 1.0), every gradient is scaled down proportionally. This directly addresses the exploding gradient problem from Part 0 §06 — even with LSTMs or Transformers, anomalous batches can produce sudden large gradients, and clipping prevents them from causing catastrophic weight updates. Second, a learning rate schedule: the rate warms up from zero over the first few thousand steps (when initial random weights produce unreliable gradient directions), then decays through the rest of training.
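Clipping by global norm can be sketched as follows. The function name `clipByGlobalNorm` and the layout of `grads` (an object mapping parameter names to flat arrays) are assumptions made for this example.

```javascript
// Treat all gradients as one long vector; if its L2 norm exceeds maxNorm,
// scale every entry by maxNorm / norm. Direction is preserved.
function clipByGlobalNorm(grads, maxNorm = 1.0) {
  let sumSquares = 0;
  for (const g of Object.values(grads)) for (const x of g) sumSquares += x * x;
  const norm = Math.sqrt(sumSquares);
  const scale = norm > maxNorm ? maxNorm / norm : 1;
  const clipped = {};
  for (const [name, g] of Object.entries(grads)) clipped[name] = g.map(x => x * scale);
  return { clipped, norm, scale };
}

// Made-up gradients from an "anomalous batch": global norm is 5.
const spikyGrads = { layer1: [3, 0], layer2: [0, 4] };
const { clipped, norm, scale } = clipByGlobalNorm(spikyGrads, 1.0);
console.log(`norm before: ${norm}, scale: ${scale}`);
console.log(clipped);
```

Because every gradient is multiplied by the same factor, the update keeps its direction and only its length is capped. Gradients whose global norm is already below the threshold pass through unchanged.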

train.js — a complete training loop on a tiny model JS
// Complete training loop — a tiny bigram neural network trained
// end-to-end: forward pass → cross-entropy loss → backprop → Adam.
// The same loop structure is what frameworks such as PyTorch run at scale.

// ── Model: single embedding matrix, no transformer layers ────────
// This is the simplest possible LM: predict next token from current token.
// Output = softmax(E[input] · Eᵀ) — weight-tied embedding/unembedding.

const VOCAB2 = ['the','cat','sat','on','mat'];
const V = VOCAB2.length;  // 5
const D2 = 4;             // embedding dimension

// Initialize embedding table with small random values
function randn(r, c, scale=0.1) {
  return Array.from({length:r}, () =>
    Array.from({length:c}, () => {
      const u=1-Math.random(), v=Math.random();
      return Math.sqrt(-2*Math.log(u)) * Math.cos(2*Math.PI*v) * scale;
    })
  );
}

let E = randn(V, D2);  // [V × D] — shared embedding/unembedding

function forward2(inputId) {
  const emb   = E[inputId];                           // [D]
  // logits = emb · Eᵀ — dot with every row of E
  const logits = E.map(row => row.reduce((s,v,i) => s + v*emb[i], 0)); // [V]
  const max    = Math.max(...logits);
  const exps   = logits.map(x => Math.exp(x - max));
  const sum    = exps.reduce((a,b) => a+b, 0);
  const probs  = exps.map(x => x/sum);
  return { emb, logits, probs };
}

// Cross-entropy loss + gradients w.r.t. E
function lossAndGrad(inputId, targetId) {
  const { emb, probs } = forward2(inputId);
  const loss = -Math.log(Math.max(probs[targetId], 1e-10));

  // dL/d_logits[j] = probs[j] - 1{j == target}
  const dLogits = probs.map((p, j) => p - (j === targetId ? 1 : 0));

  // dL/d_E: each logit[j] = emb · E[j], so:
  //   dL/d_E[j] += dLogits[j] · emb        (outer product, j-th row)
  //   dL/d_E[inputId] += sum_j(dLogits[j] · E[j])   (chain from emb)
  const dE = Array.from({length:V}, () => new Array(D2).fill(0));

  for (let j = 0; j < V; j++)
    for (let d = 0; d < D2; d++)
      dE[j][d] += dLogits[j] * emb[d];

  for (let j = 0; j < V; j++)
    for (let d = 0; d < D2; d++)
      dE[inputId][d] += dLogits[j] * E[j][d];

  return { loss, dE };
}

// ── Adam state for E ──────────────────────────────────────────────
const mE = Array.from({length:V}, () => new Array(D2).fill(0));
const vE = Array.from({length:V}, () => new Array(D2).fill(0));
let   t  = 0;
const lr=0.05, beta1=0.9, beta2=0.999, eps=1e-8;

// ── Training data: bigram pairs from corpus ───────────────────────
const corpus2 = [0,1,2,3,0,4,0,1,3,0,4];  // "the cat sat on the mat..."
const pairs = [];
for (let i = 0; i < corpus2.length - 1; i++)
  pairs.push([corpus2[i], corpus2[i+1]]);

// ── Training loop ─────────────────────────────────────────────────
const steps = 200;
const logEvery = 40;
for (let step = 0; step < steps; step++) {
  t++;
  // Average gradients over all pairs
  const totalDE = Array.from({length:V}, () => new Array(D2).fill(0));
  let totalLoss = 0;

  for (const [inp, tgt] of pairs) {
    const { loss, dE } = lossAndGrad(inp, tgt);
    totalLoss += loss;
    for (let i=0;i<V;i++) for (let d=0;d<D2;d++) totalDE[i][d] += dE[i][d]/pairs.length;
  }
  totalLoss /= pairs.length;

  // Adam update on E
  for (let i=0;i<V;i++) for (let d=0;d<D2;d++) {
    const g = totalDE[i][d];
    mE[i][d] = beta1*mE[i][d] + (1-beta1)*g;
    vE[i][d] = beta2*vE[i][d] + (1-beta2)*g*g;
    const mh = mE[i][d]/(1-Math.pow(beta1,t));
    const vh = vE[i][d]/(1-Math.pow(beta2,t));
    E[i][d] -= lr * mh / (Math.sqrt(vh)+eps);
  }

  if (step % logEvery === 0 || step === steps-1)
    console.log(`step ${String(step+1).padStart(3)}: loss=${totalLoss.toFixed(4)}, perplexity=${Math.exp(totalLoss).toFixed(2)}`);
}

// ── Show what the model learned ───────────────────────────────────
console.log('\nLearned predictions (top-2 next token after each word):');
VOCAB2.forEach((word, id) => {
  const { probs } = forward2(id);
  const top2 = probs.map((p,i)=>[VOCAB2[i],p]).sort((a,b)=>b[1]-a[1]).slice(0,2);
  console.log(`  after "${word}": ${top2.map(([w,p])=>`"${w}"(${(p*100).toFixed(0)}%)`).join(', ')}`);
});

Notice that loss drops and perplexity falls toward the theoretical minimum for this small corpus. The model learns which tokens tend to follow which — the same information captured, at scale, by a real language model trained on the internet.

§07

Overfitting & generalization

A model that has memorized its training data perfectly will perform poorly on new text — it has learned the idiosyncrasies of training examples rather than the underlying structure of language. This is overfitting.

The standard diagnostic is to plot training loss and validation loss separately throughout training. Validation loss is computed on a held-out set not used for updates. Healthy training: both curves fall together. Overfitting: training loss keeps falling while validation loss levels off or rises.
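One common way to act on this diagnostic is a patience rule: keep the checkpoint with the lowest validation loss and stop once it has not improved for a set number of evaluations. The sketch below assumes that rule; the function name, the `patience` parameter, and both loss curves are invented for illustration.

```javascript
// Given per-checkpoint validation losses, return the best checkpoint and
// the checkpoint at which a patience rule would stop training.
function earlyStop(valLosses, patience = 2) {
  let bestIdx = 0;
  for (let i = 1; i < valLosses.length; i++) {
    if (valLosses[i] < valLosses[bestIdx]) bestIdx = i;
    else if (i - bestIdx >= patience) return { bestIdx, stoppedAt: i };
  }
  return { bestIdx, stoppedAt: null }; // rule never triggered
}

// Made-up curves showing the overfitting pattern.
const trainLoss = [3.2, 2.6, 2.1, 1.7, 1.4, 1.1, 0.9];   // keeps falling
const valLoss   = [3.3, 2.8, 2.4, 2.3, 2.35, 2.5, 2.7];  // turns around
const { bestIdx, stoppedAt } = earlyStop(valLoss, 2);
console.log(`best checkpoint: ${bestIdx}, stopped at: ${stoppedAt}`);
console.log(`train/val gap at stop: ${(valLoss[stoppedAt] - trainLoss[stoppedAt]).toFixed(2)}`);
```

A widening gap between the two curves at the stopping point is the signature of overfitting described above.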


LLMs avoid overfitting primarily through sheer scale: training data is so vast that the model rarely sees the same document twice. Additional techniques include dropout (randomly zeroing activations during training), weight decay (an L2 penalty that keeps weights small), and careful early stopping based on validation loss.

Technique | What it does | Intuition
Dropout | Zero out random activations with probability p | Forces the network not to rely on any single path
Weight decay | Add λ·‖w‖² to the loss at each step | Penalizes large weights; shrinks them toward zero
Early stopping | Halt when validation loss stops improving | The validation curve tells you when to stop
Data augmentation | Train on varied rephrasings of training examples | More unique examples per concept
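The first two rows of the table are simple enough to sketch directly. Two assumptions are worth flagging: the dropout below uses the "inverted" variant (survivors are scaled by 1/(1−p) during training so nothing needs rescaling at inference), and weight decay is shown as a plain gradient step on the penalized loss rather than inside Adam. Function names and numbers are illustrative.

```javascript
// Inverted dropout: zero each activation with probability p and scale the
// survivors by 1/(1-p) so the expected value is unchanged.
function dropout(activations, p, rng = Math.random) {
  if (p <= 0) return activations.slice();
  return activations.map(a => (rng() < p ? 0 : a / (1 - p)));
}

// L2 weight decay folded into a gradient step:
// loss' = loss + lambda * w^2  =>  grad' = grad + 2 * lambda * w
function sgdStepWithDecay(w, grad, lr, lambda) {
  return w - lr * (grad + 2 * lambda * w);
}

console.log(dropout([1, 2, 3, 4], 0.5));
// Even with a zero data gradient, the weight shrinks toward zero.
const decayed = sgdStepWithDecay(2.0, 0.0, 0.1, 0.01);
console.log(decayed);
```

Dropout is applied only during training; at inference the activations are used as they are.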
§08

Autoregressive generation

At inference time, the model has no target token — it must generate. The process is autoregressive: the model produces one token at a time, appends it to the context, then runs the forward pass again to predict the next one.

This is the most computationally expensive characteristic of LLM inference. Generating 100 tokens requires 100 sequential forward passes. In the naive implementation, each pass processes the full context window, which grows by one token at each step. There is no way to parallelize across time steps, because each token depends on all previous ones.

generate.js — autoregressive generation loop JS
// Autoregressive generation: produce one token at a time,
// feeding each output back as input for the next step.

const VOCAB3 = ['the','cat','sat','on','mat','dog','ran','a','big'];
const V3 = VOCAB3.length;

// Toy transition matrix (like a trained model's implicit knowledge)
// entry[i][j] = unnormalized log-score (logit) of token j following token i;
// rows do not sum to 1 until softmax2 below turns them into probabilities.
const logProbs = [
  // the   cat   sat   on    mat   dog   ran   a     big
  [-2.2, -1.2, -3.0, -3.0, -1.8, -1.5, -3.0, -1.3, -2.0],  // the
  [-3.0, -3.5, -1.0, -3.0, -3.0, -3.5, -3.5, -3.5, -3.5],  // cat
  [-3.5, -3.5, -3.5, -1.2, -3.5, -3.5, -3.5, -3.5, -3.5],  // sat
  [-1.0, -3.5, -3.5, -3.5, -1.5, -3.5, -3.5, -1.2, -3.5],  // on
  [-3.5, -3.5, -3.5, -3.5, -3.5, -3.5, -3.5, -3.5, -3.5],  // mat
  [-3.0, -3.5, -1.5, -3.0, -3.5, -3.5, -1.0, -3.5, -3.5],  // dog
  [-2.0, -3.5, -3.5, -1.5, -3.5, -3.5, -3.5, -3.5, -3.5],  // ran
  [-1.2, -1.3, -3.5, -3.5, -1.5, -1.3, -3.5, -3.5, -1.2],  // a
  [-3.5, -1.2, -3.5, -3.5, -3.5, -1.3, -3.5, -3.5, -3.5],  // big
];

function getLogits(contextTokens) {
  // Use last token for predictions (bigram model)
  const lastId = contextTokens[contextTokens.length - 1];
  return logProbs[lastId];
}

function softmax2(logits) {
  const max = Math.max(...logits);
  const ex  = logits.map(x => Math.exp(x - max));
  const sum = ex.reduce((a,b) => a+b, 0);
  return ex.map(x => x/sum);
}

function sampleFromProbs(probs) {
  let r = Math.random(), cumul = 0;
  for (let i = 0; i < probs.length; i++) {
    cumul += probs[i];
    if (r < cumul) return i;
  }
  return probs.length - 1;
}

function generate2(prompt, maxTokens=8) {
  const context = prompt.map(w => VOCAB3.indexOf(w));
  const generated = [...context];

  console.log(`Prompt: "${prompt.join(' ')}"`);
  console.log('Generating...\n');

  for (let step = 0; step < maxTokens; step++) {
    const logits = getLogits(generated);
    const probs  = softmax2(logits);

    // Show top-3 candidates at this step
    const top3 = probs.map((p,i)=>[VOCAB3[i],p])
      .sort((a,b)=>b[1]-a[1]).slice(0,3);
    const nextId = sampleFromProbs(probs);
    const nextTok = VOCAB3[nextId];

    console.log(`Step ${step+1}: top candidates: ${top3.map(([w,p])=>`"${w}"(${(p*100).toFixed(0)}%)`).join(' | ')}`);
    console.log(`         → sampled: "${nextTok}"`);
    generated.push(nextId);

    // Stop at sentence-ending tokens
    if (nextTok === 'mat') break;
  }

  const result = generated.map(id => VOCAB3[id]).join(' ');
  console.log(`\nFull output: "${result}"`);
}

generate2(['the']);

Every call to getLogits above corresponds to one full forward pass through a real Transformer. The fact that each step is sequential — and that context grows by one each time — is the fundamental bottleneck of LLM serving.

§09

Sampling strategies

Given a probability distribution over the vocabulary, how do you pick the next token? The choice matters enormously for the quality of generated text.

Greedy decoding

Always pick the highest-probability token. Deterministic, but tends to produce repetitive, bland text. The model can get stuck in loops ("the the the the...").
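Greedy decoding is just an argmax inside the generation loop. In the sketch below the three-word vocabulary and the `nextLogits` table are hypothetical, chosen so that the looping behaviour is easy to see.

```javascript
// Greedy decoding: always take the argmax. No randomness involved.
function argmax(values) {
  let best = 0;
  for (let i = 1; i < values.length; i++) if (values[i] > values[best]) best = i;
  return best;
}

// Hypothetical bigram logits: rows = current token, columns = next token.
const words = ['the', 'cat', 'sat'];
const nextLogits = [
  [0.1, 2.0, 0.3],   // after "the": "cat" scores highest
  [0.2, 0.1, 1.5],   // after "cat": "sat" scores highest
  [1.8, 0.4, 0.2],   // after "sat": "the" scores highest
];

function greedyGenerate(startId, steps) {
  const out = [startId];
  for (let s = 0; s < steps; s++) out.push(argmax(nextLogits[out[out.length - 1]]));
  return out.map(id => words[id]).join(' ');
}

const greedyText = greedyGenerate(0, 8);
console.log(greedyText);
```

Running it twice gives the same text, and because the highest-scoring transitions form a cycle, the output repeats the same three words for as long as generation continues.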

Temperature

Before taking softmax, divide all logits by a temperature parameter T. High T (>1) flattens the distribution — more random. Low T (<1) sharpens it — more confident, less creative. T=1 is unmodified. T→0 approaches greedy.

Top-k sampling

Zero out all but the top-k tokens by probability, renormalize, then sample. Prevents the model from ever choosing very low-probability tokens. Typical k = 40–200.

Top-p (nucleus) sampling

Instead of a fixed k, keep the smallest set of tokens whose cumulative probability reaches p. If the distribution is very peaked (model is confident), only a few tokens pass; if flat (model is uncertain), more do. Typical p = 0.9–0.95.

sampling.js — temperature, top-k, top-p JS
// All three sampling strategies from scratch.

function softmaxWithTemp(logits, temp=1.0) {
  const scaled = logits.map(x => x / temp);
  const max    = Math.max(...scaled);
  const ex     = scaled.map(x => Math.exp(x - max));
  const sum    = ex.reduce((a,b) => a+b, 0);
  return ex.map(x => x/sum);
}

function topKFilter(probs, k) {
  if (k <= 0) return probs;
  const indexed = probs.map((p,i) => [p,i]).sort((a,b) => b[0]-a[0]);
  const mask    = new Array(probs.length).fill(0);
  indexed.slice(0, k).forEach(([,i]) => mask[i] = 1);
  const filtered = probs.map((p,i) => p * mask[i]);
  const sum = filtered.reduce((a,b) => a+b, 0);
  return filtered.map(p => p/sum);
}

function topPFilter(probs, p) {
  if (p >= 1.0) return probs;
  const indexed = [...probs.map((v,i)=>[v,i])].sort((a,b)=>b[0]-a[0]);
  let cumul = 0, cutoff = 0;
  for (const [v,] of indexed) {
    cumul += v;
    cutoff++;
    if (cumul >= p) break;
  }
  const topTokens = new Set(indexed.slice(0, cutoff).map(([,i])=>i));
  const filtered  = probs.map((v,i) => topTokens.has(i) ? v : 0);
  const sum = filtered.reduce((a,b) => a+b, 0);
  return filtered.map(v => v/sum);
}

function sample2(probs) {
  let r = Math.random(), cumul = 0;
  for (let i=0;i<probs.length;i++) { cumul+=probs[i]; if(r<cumul) return i; }
  return probs.length-1;
}

// Example: model logits for 8 candidate tokens
const CANDS = ['the','cat','sat','ran','mat','dog','on','a'];
const rawLogits = [1.2, 2.8, 0.3, -0.5, 1.0, 2.1, 0.8, 1.5];

function showSampling(temp, k, p) {
  let probs = softmaxWithTemp(rawLogits, temp);
  if (k > 0) probs = topKFilter(probs, k);
  if (p < 1) probs = topPFilter(probs, p);
  const bar = v => '█'.repeat(Math.round(v*30));
  console.log(`\n--- temp=${temp}, top-k=${k||'off'}, top-p=${p} ---`);
  probs.map((p,i)=>[CANDS[i],p]).sort((a,b)=>b[1]-a[1])
    .forEach(([w,p]) => {
      if (p < 0.001) return;
      console.log(`  ${w.padEnd(5)} ${(p*100).toFixed(1).padStart(5)}%  ${bar(p)}`);
    });
}

showSampling(1.0,  0,   1.0);   // baseline
showSampling(0.5,  0,   1.0);   // low temperature (sharper)
showSampling(2.0,  0,   1.0);   // high temperature (flatter)
showSampling(1.0,  3,   1.0);   // top-k=3
showSampling(1.0,  0,   0.9);   // top-p=0.9

Production systems typically expose these as tunable settings and often combine them: a moderate temperature, top-p sampling, and sometimes top-k, tuned per use case. Creative writing uses higher temperature; code generation and factual Q&A use lower.

§10

The KV cache

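In the autoregressive loop, the model re-processes the entire growing context at every step. For a 1,000-token context, step 1,001 computes attention over 1,001 tokens, step 1,002 over 1,002 tokens, and so on. Without optimization, each step recomputes attention between every pair of tokens in the context, which is O(t²) work at context length t, even though almost all of it repeats what the previous step already computed.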

The KV cache eliminates almost all of this redundancy. Recall that in attention, the Key and Value vectors for each token are computed from that token's representation at each layer. With causal masking, that representation depends only on the token itself and the tokens before it, so it never changes once the token is in the context. So we compute K and V once and cache them. Subsequent steps only compute Q, K, V for the new token, then attend over the cached K and V for all previous tokens, which reduces the per-step attention cost to O(t).
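To make the claim concrete, here is a toy single-head, single-layer attention step computed two ways: recomputing K and V for the whole context, and appending to a cache. Everything in it (the 2-dimensional weights `Wq`, `Wk`, `Wv`, the embeddings, the function names) is made up for the example; with only one layer, a token's representation is simply its embedding.

```javascript
const Wq = [[0.5, -0.2], [0.1, 0.8]];
const Wk = [[0.3, 0.7], [-0.4, 0.2]];
const Wv = [[1.0, 0.0], [0.5, -0.5]];

const matVec = (W, x) => W.map(row => row.reduce((s, w, i) => s + w * x[i], 0));
const dot = (a, b) => a.reduce((s, v, i) => s + v * b[i], 0);

// Scaled dot-product attention for a single query over all keys/values.
function attend(q, keys, values) {
  const scores = keys.map(k => dot(q, k) / Math.sqrt(q.length));
  const max = Math.max(...scores);
  const exps = scores.map(s => Math.exp(s - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  const out = new Array(values[0].length).fill(0);
  values.forEach((v, t) => v.forEach((x, d) => { out[d] += (exps[t] / sum) * x; }));
  return out;
}

// Naive: re-project K and V for every token in the context at each step.
function naiveStep(embeddings) {
  const keys = embeddings.map(x => matVec(Wk, x));
  const values = embeddings.map(x => matVec(Wv, x));
  return attend(matVec(Wq, embeddings[embeddings.length - 1]), keys, values);
}

// Cached: project only the new token, append its K and V, reuse the rest.
function makeCachedAttention() {
  const cache = { keys: [], values: [] };
  return function step(newEmbedding) {
    cache.keys.push(matVec(Wk, newEmbedding));
    cache.values.push(matVec(Wv, newEmbedding));
    return attend(matVec(Wq, newEmbedding), cache.keys, cache.values);
  };
}

const embeddings = [[1, 0], [0, 1], [0.5, 0.5], [-1, 2]];
const cachedStep = makeCachedAttention();
let maxDiff = 0;
embeddings.forEach((x, t) => {
  const cached = cachedStep(x);
  const naive = naiveStep(embeddings.slice(0, t + 1));
  cached.forEach((v, d) => { maxDiff = Math.max(maxDiff, Math.abs(v - naive[d])); });
});
console.log(`max |cached - naive| = ${maxDiff}`);
```

The two paths produce the same output at every step; the cached path simply avoids re-projecting tokens it has already seen.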

kvcache.js — cached vs uncached attention JS
// KV cache: store K and V for all past tokens, compute only the new step.
// Compare operation counts: cached vs naive.

function naiveAttentionCost(seqLen) {
  // No cache: every step rebuilds K and V for ALL tokens and recomputes the
  // full seqLen × seqLen attention score matrix.
  // Total for generating n tokens from a prompt of length p:
  //   sum_{i=p+1}^{p+n} i², which grows cubically in the generated length.
  return seqLen * seqLen;  // for one step at seqLen
}

function cachedAttentionCost(newPos) {
  // Only compute Q, K, V for the NEW token (constant work)
  // Then attend over newPos cached K,V pairs
  const computeNewQKV = 1;         // constant
  const attendOverCache = newPos;  // linear in sequence length
  return computeNewQKV + attendOverCache;
}

// ── Simulate generation cost ──────────────────────────────────────
const promptLen = 10;
const genTokens = 20;

console.log('Operation count comparison (prompt=10 tokens, generating 20):');
console.log('Step | Naive (full recompute) | Cached | Speedup');
console.log('-----|----------------------|--------|--------');

let naiveTotal = 0, cachedTotal = 0;
for (let step = 1; step <= genTokens; step++) {
  const seqLen = promptLen + step;
  const naive  = naiveAttentionCost(seqLen);
  const cached = cachedAttentionCost(seqLen);
  naiveTotal  += naive;
  cachedTotal += cached;
  if (step <= 5 || step === genTokens) {
    const speedup = (naive/cached).toFixed(1);
    console.log(`  ${String(step).padStart(2)} |   ${String(naive).padStart(6)} ops       |  ${String(cached).padStart(4)} ops | ${speedup}×`);
  } else if (step === 6) {
    console.log('  ...');
  }
}
console.log(`\nTotal: naive=${naiveTotal} ops, cached=${cachedTotal} ops`);
console.log(`Overall speedup: ${(naiveTotal/cachedTotal).toFixed(1)}×`);

// ── Memory cost of the KV cache ───────────────────────────────────
// GPT-2 small's shape (12 layers, 12 heads, dk=64) at a hypothetical 4096-token
// context; the released GPT-2 models use a 1,024-token context.
const layers=12, heads=12, dk=64, contextLen=4096;
const bytesPerFloat=2;  // fp16
const cacheSizeBytes = 2 * layers * heads * dk * contextLen * bytesPerFloat;
const cacheSizeMB = (cacheSizeBytes / 1e6).toFixed(1);
console.log(`\nKV cache memory (GPT-2-small shape, 4096 ctx, fp16): ${cacheSizeMB} MB`);
console.log('(2 = K+V, layers=12, heads=12, dk=64, ctx=4096, 2 bytes/val)');

The KV cache trades memory for compute — at long context lengths it can consume gigabytes of GPU memory. Systems serving many concurrent users must manage a pool of KV caches, evicting them for idle sessions (similar to how an OS manages physical memory pages). This is one of the central engineering challenges in LLM serving infrastructure.
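A pool with eviction can be sketched with a plain `Map`. The least-recently-used policy here is an assumption chosen for simplicity (the text above only says idle sessions are evicted), and the class name `KvCachePool` and the session IDs are hypothetical.

```javascript
// Holds at most `capacity` per-session KV caches and evicts the least
// recently used one. Map iterates in insertion order, so re-inserting a
// session on access moves it to the "most recent" end.
class KvCachePool {
  constructor(capacity) {
    this.capacity = capacity;
    this.caches = new Map(); // sessionId -> { keys: [], values: [] }
  }
  acquire(sessionId) {
    let cache = this.caches.get(sessionId);
    if (cache) {
      this.caches.delete(sessionId);        // refresh recency
    } else {
      cache = { keys: [], values: [] };     // cold start: context must be re-encoded
      if (this.caches.size >= this.capacity) {
        const oldest = this.caches.keys().next().value;
        this.caches.delete(oldest);         // evict the idle session
      }
    }
    this.caches.set(sessionId, cache);
    return cache;
  }
}

const pool = new KvCachePool(2);
pool.acquire('alice');
pool.acquire('bob');
pool.acquire('alice');   // alice becomes most recent
pool.acquire('carol');   // pool is full, so bob is evicted
console.log([...pool.caches.keys()]);
```

An evicted session is not lost; its next request simply pays the cost of re-encoding the whole conversation before generation can resume.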

§11

RLHF: from predictor to assistant

A pretrained LLM is a next-token predictor. Given "The capital of France is", it will output "Paris" because that is what follows in training data. But it will also complete "How do I pick a lock?" with detailed instructions, because those also appear on the internet. It has no concept of helpfulness, harmlessness, or honesty — only statistical continuation.

Reinforcement Learning from Human Feedback (RLHF) is the process that shapes a pretrained model into an assistant that follows instructions, declines harmful requests, and gives helpful answers. It is not one algorithm but a pipeline of three stages.

Stage 1: Supervised fine-tuning (SFT)

Human contractors write high-quality (prompt, response) pairs demonstrating desired behavior: helpful answers, refusals of harmful requests, honest admissions of uncertainty. The model is fine-tuned on these pairs using standard cross-entropy loss — same as pretraining, just on a small curated dataset (typically 10,000–100,000 examples). This produces a model that can follow instructions but inconsistently.

Stage 2: Reward model training

For each prompt, the SFT model generates several responses. Human raters rank them by quality. A separate model — the reward model — is trained to predict these human preference rankings. It takes (prompt, response) pairs as input and outputs a scalar reward score. The reward model learns what "good" looks like from human judgment.

Stage 3: RL fine-tuning (PPO)

The SFT model is now called the policy. In reinforcement learning, a policy is simply a function that maps a state (the conversation so far) to an action (the next token to generate). The policy generates responses; the reward model scores them. The policy weights are updated by PPO (Proximal Policy Optimization), a gradient-based algorithm that maximises expected reward while deliberately limiting how large any single update can be: its objective clips the ratio between the new and old probability of each sampled token, so an update gains nothing from pushing that ratio outside a set range around 1. This prevents catastrophic updates where one very good or very bad example sends the weights into a poor region. A KL-divergence penalty is also added to the loss. KL divergence measures how different two probability distributions are; here it measures how far the updated policy has drifted from the original SFT model. The penalty discourages "reward hacking": generating text that scores high on the reward model without actually being helpful.
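
The clipping and the KL penalty can be sketched for a single sampled token. This is a simplified illustration, not a full PPO implementation: it assumes the advantage (how much better than expected the sampled token turned out) has already been computed, and it uses a one-sample estimate of the KL term. The function names, `epsilon = 0.2`, and `beta = 0.1` are illustrative assumptions.

```javascript
// Clipped surrogate objective for one sampled token (higher is better).
// ratio = pi_new(token) / pi_old(token), computed from log-probabilities.
function clippedSurrogate(logpNew, logpOld, advantage, epsilon = 0.2) {
  const ratio = Math.exp(logpNew - logpOld);
  const clipped = Math.min(Math.max(ratio, 1 - epsilon), 1 + epsilon);
  // Taking the minimum means the objective never rewards moving the ratio
  // further outside [1 - epsilon, 1 + epsilon].
  return Math.min(ratio * advantage, clipped * advantage);
}

// One-sample estimate of KL(policy || reference) at the sampled token.
function klEstimate(logpNew, logpRef) {
  return logpNew - logpRef;
}

// Loss to minimise: negative surrogate plus the KL penalty.
function ppoLoss({ logpNew, logpOld, logpRef, advantage }, epsilon = 0.2, beta = 0.1) {
  const surrogate = clippedSurrogate(logpNew, logpOld, advantage, epsilon);
  return -surrogate + beta * klEstimate(logpNew, logpRef);
}

// A good token (advantage +1) whose probability rose from 0.20 to 0.30:
// the ratio is about 1.5, but the objective is capped at 1.2.
const bigStep = clippedSurrogate(Math.log(0.30), Math.log(0.20), 1);

// The same token with a smaller rise, 0.20 to 0.22: no clipping applies.
const smallStep = clippedSurrogate(Math.log(0.22), Math.log(0.20), 1);

// A bad token (advantage -1) whose probability was halved: the objective
// stops improving once the ratio falls below 0.8.
const badToken = clippedSurrogate(Math.log(0.10), Math.log(0.20), -1);

console.log(`big step:   ${bigStep.toFixed(3)}`);
console.log(`small step: ${smallStep.toFixed(3)}`);
console.log(`bad token:  ${badToken.toFixed(3)}`);

const loss = ppoLoss({
  logpNew: Math.log(0.30), logpOld: Math.log(0.20),
  logpRef: Math.log(0.20), advantage: 1,
});
console.log(`loss with KL penalty: ${loss.toFixed(3)}`);
```

Once the ratio leaves the clipping range in the direction the advantage favours, the objective becomes flat, so that sample contributes no further gradient. That is the mechanism that keeps each update small.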

RLHF pipeline overview
rlhf.js — reward model training (conceptual) JS
// Reward model training — the core of Stage 2.
// A reward model learns to predict human preference between two responses.
// The Bradley-Terry model: P(A preferred over B) = σ(r(A) - r(B))

function sigmoid(x) { return 1 / (1 + Math.exp(-x)); }

// Tiny reward model: scalar dot product with a feature vector
// In reality this is a full Transformer; here it's a linear model
// for illustrative purposes.
function rewardModel(response, weights) {
  // Features: [length_score, clarity_score, relevance_score, safety_score]
  const features = extractFeatures(response);
  return features.reduce((s, f, i) => s + f * weights[i], 0);
}

function extractFeatures(response) {
  // Simulated features — in a real system, these come from the Transformer
  return [
    Math.min(response.length / 100, 1.0),       // length (normalized)
    response.includes('because') ? 0.8 : 0.3,   // explanation present
    response.includes('?') ? 0.2 : 0.6,         // not a question back
    response.includes('harmful') ? 0.0 : 0.9,   // safety feature
  ];
}

// Bradley-Terry loss: maximize P(preferred > rejected)
function preferenceLoss(rPreferred, rRejected) {
  return -Math.log(sigmoid(rPreferred - rRejected) + 1e-10);
}

// Training data: (preferred response, rejected response) pairs
const preferences = [
  {
    preferred: "The capital of France is Paris, because it has been the political center since the 10th century.",
    rejected:  "Paris I think? Not sure.",
  },
  {
    preferred: "I cannot help with that request as it could cause harm.",
    rejected:  "Sure, here is how to do something harmful...",
  },
  {
    preferred: "Photosynthesis converts light energy into glucose because chlorophyll absorbs sunlight.",
    rejected:  "Plants make food?",
  },
];

// Initialize reward model weights
let weights = [0.5, 0.5, 0.5, 0.5];
const lr = 0.1;

console.log('Training reward model on human preference data:\n');
for (let epoch = 0; epoch < 30; epoch++) {
  let totalLoss = 0;
  const dW = new Array(4).fill(0);

  for (const { preferred, rejected } of preferences) {
    const rP = rewardModel(preferred, weights);
    const rR = rewardModel(rejected, weights);
    const loss = preferenceLoss(rP, rR);
    totalLoss += loss;

    // Gradient: dLoss/dW = -( 1 - σ(rP-rR) ) * (featP - featR)
    const grad_factor = -(1 - sigmoid(rP - rR));
    const featP = extractFeatures(preferred);
    const featR = extractFeatures(rejected);
    for (let i=0;i<4;i++) dW[i] += grad_factor * (featP[i] - featR[i]);
  }

  for (let i=0;i<4;i++) weights[i] -= lr * dW[i] / preferences.length;

  if (epoch % 10 === 0 || epoch === 29)
    console.log(`Epoch ${String(epoch+1).padStart(2)}: avg loss = ${(totalLoss/preferences.length).toFixed(4)}`);
}

console.log('\nLearned reward weights:', weights.map(w=>w.toFixed(3)));
const featureNames = ['length','explains','direct','safe'];
weights.forEach((w,i) => console.log(`  ${featureNames[i].padEnd(10)} → weight ${w.toFixed(3)}`));

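// Score some responses.
// Note: the keyword-based safety feature is crude. The third test is a
// refusal, but it contains the word "harmful", so its safety feature is 0.
console.log('\nReward scores for sample responses:');
const tests = [
  "The answer is 42 because the universe requires it.",
  "I dunno",
  "I cannot help with anything harmful to users.",
];
tests.forEach(r => {
  // Append an ellipsis only when the response was actually truncated
  const shown = r.length > 45 ? r.slice(0, 45) + '...' : r;
  console.log(`  "${shown}" → score ${rewardModel(r, weights).toFixed(3)}`);
});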

Beyond RLHF: newer alignment approaches include Constitutional AI (developed by Anthropic and used in training Claude), where a model critiques its own outputs against a set of principles rather than relying solely on human raters, and Direct Preference Optimization (DPO), which reformulates the Stage 3 objective mathematically so that the separate reward model is eliminated entirely. In DPO the policy is trained directly on human preference pairs (A is better than B) using a closed-form loss derived from the same KL-constrained objective that PPO optimises. The field is evolving rapidly.
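
For comparison with the Bradley-Terry loss in the listing above, here is a minimal sketch of the DPO loss for a single preference pair. It assumes the summed log-probabilities of each full response under the policy and under a frozen reference model (the SFT model) are already available. The function name `dpoLoss`, the field names, the toy numbers, and `beta = 0.1` are illustrative assumptions.

```javascript
function sigmoid(x) { return 1 / (1 + Math.exp(-x)); }

// DPO loss for one (chosen, rejected) pair.
// Each argument is a log-probability of a whole response given the prompt.
function dpoLoss({ policyChosen, policyRejected, refChosen, refRejected }, beta = 0.1) {
  // How much more (or less) likely each response is under the policy than
  // under the reference. These shifts play the role of the reward scores.
  const chosenShift   = policyChosen - refChosen;
  const rejectedShift = policyRejected - refRejected;
  return -Math.log(sigmoid(beta * (chosenShift - rejectedShift)) + 1e-10);
}

// At the start of training the policy is a copy of the reference model,
// so both shifts are zero and the loss is -log(0.5).
const atInit = dpoLoss({
  policyChosen: -40, policyRejected: -35,
  refChosen:    -40, refRejected:    -35,
});

// After some training the chosen response has become more likely and the
// rejected one less likely, relative to the reference.
const afterTraining = dpoLoss({
  policyChosen: -30, policyRejected: -45,
  refChosen:    -40, refRejected:    -35,
});

console.log(`loss at initialisation: ${atInit.toFixed(4)}`);
console.log(`loss after training:    ${afterTraining.toFixed(4)}`);
```

The structure is the same as `preferenceLoss`: a sigmoid over a difference of two scores. The difference is that the scores come from the policy's own log-probabilities, measured against the reference model, so no separate reward model has to be trained.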

After RLHF, the model has shifted from a raw next-token predictor to something that follows instructions, acknowledges uncertainty, declines harmful requests, and produces helpful responses. The architecture is unchanged, and what the model learned in pretraining is largely retained. What fine-tuning changes is the weights themselves, adjusted so that probability mass moves toward the kinds of responses human raters preferred.


Series complete

You have now seen the full picture — from text to tokens, tokens to embeddings, embeddings through attention and feedforward layers into logits, random weights shaped by gradient descent and Adam into a language model, and finally RLHF nudging that model toward helpfulness. Everything from first principles.

Part 1: Attention & the Transformer


Where to go next

This series gave you the mechanics. These go deeper in complementary directions.

What | Why it matters
Attention Is All You Need (Vaswani et al., 2017) | The original Transformer paper. Short, readable, and now that you have the vocabulary you can follow every equation.
Language Models are Few-Shot Learners (Brown et al., 2020; GPT-3) | Establishes the scaling-laws era: same architecture as Part 1, trained at enormous scale. Section 2 (model and architecture) is directly readable after this series.
makemore (Andrej Karpathy) | Builds a character-level language model from scratch in Python, step by step. It is the closest real-code equivalent of what this series built in JS. The accompanying YouTube lectures cover backprop through attention in detail.
nanoGPT (Andrej Karpathy) | A minimal, trainable GPT-2 implementation in ~300 lines of PyTorch. Read it after makemore: the causal mask, weight tying, and training loop you saw here all appear directly.
Training language models to follow instructions… (Ouyang et al., 2022; InstructGPT) | The RLHF paper that underlies Part 2 §11. Section 3 (method) maps directly onto the three stages described here.