GPT From Scratch - A Simple Primer¶

Hi! I am Srihari Unnikrishnan (@pythongiant), and this is a small primer on transformers, encoder–decoder transformers, and generating text with a decoder-only transformer.

Send me an email at srihari[dot]unnikrishnan[at]gmail[dot]com if you have any questions or any suggestions/recommendations!
I would be more than glad to help you out.

Encoder - Decoder Transformer¶

An Encoder–Decoder Transformer is the original Transformer architecture introduced in Attention Is All You Need (Vaswani et al., 2017). It is designed to model conditional sequence generation.

Why encoder–decoder lost dominance¶

Encoder–decoder models were very strong around 2018–2020. Decoder-only models won out because they:

  • Scale better (simpler architecture, fewer attention paths)

  • Have perfect training–inference alignment

  • Can absorb conditional tasks via prompting

  • Are easier to use as general-purpose models

  • Enable efficient KV caching during generation

In practice, decoder-only models learned to simulate encoder–decoder behavior by treating the “input” as part of the prefix.

Architecture¶

(Architecture diagram of the original encoder–decoder Transformer.)

Positional Embeddings¶

Positional embeddings add explicit numerical information about the order (position) of elements in a sequence to their vector representations. This is necessary because Transformer models process all tokens in parallel using self-attention and therefore have no inherent notion of sequence order. Without positional information, a Transformer would treat a sentence as a set of tokens rather than an ordered sequence.

To solve this, each position in a sequence (e.g., 1st word, 2nd word, 3rd word, etc.) is assigned a position-specific vector. This positional vector is added to the token’s embedding before being passed into the model. As a result, the final input representation contains both:

  • What the token is (semantic content from the word embedding), and

  • Where the token appears (its position in the sequence).

This allows the model to distinguish between sequences with the same words but different orderings, such as “dog bites man” versus “man bites dog”, where meaning critically depends on position.

The example below implements fixed, absolute sinusoidal positional embeddings as introduced in Attention Is All You Need.

# Position Encoding
import torch
import torch.nn as nn
import math


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()  # Initialize nn.Module internals

        # Create a (max_len x d_model) matrix to store positional encodings
        pe = torch.zeros(max_len, d_model)

        # Create position indices: [0, 1, 2, ..., max_len-1]
        # Shape: (max_len, 1)
        position = torch.arange(0, max_len).unsqueeze(1)

        # Compute the frequency scaling term for each even dimension
        # Shape: (d_model / 2,)
        #
        # Each pair of embedding dimensions (2i, 2i+1) shares a unique frequency.
        # These frequencies are exponentially spaced so that:
        #   - Early dimensions oscillate quickly (high-frequency signals),
        #     capturing fine-grained, local positional differences.
        #   - Later dimensions oscillate slowly (low-frequency signals),
        #     capturing long-range, global positional structure.
        #
        # This multi-scale design allows the model to represent both short- and
        # long-distance relationships between tokens.
        #
        # The exponential scaling (10000^(-2i / d_model)) ensures:
        #   - Positional encodings remain unique across long sequences
        #   - Relative position information can be recovered using linear operations,
        #     which is critical for attention mechanisms.
        #
        # Using fixed (non-learned) frequencies avoids overfitting to training
        # sequence lengths and enables generalization to longer sequences.

        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )

        # Apply sine to even embedding dimensions (0, 2, 4, ...)
        # Broadcasting: (max_len, 1) * (d_model/2,) -> (max_len, d_model/2)
        pe[:, 0::2] = torch.sin(position * div_term)

        # Apply cosine to odd embedding dimensions (1, 3, 5, ...)
        pe[:, 1::2] = torch.cos(position * div_term)

        # Note: the encoding depends only on a token's position, never on its
        # batch index. If positional encodings differed per batch element, the
        # model could pick up spurious signals such as "this sentence is in
        # batch slot 3" or "batch index correlates with label", which would
        # break generalization. Sharing one encoding per position guarantees
        # that only token order is encoded.

        # Add a batch dimension so it can be broadcast across batches
        # Final shape: (1, max_len, d_model)
        pe = pe.unsqueeze(0)

        self.register_buffer("pe", pe)

    def forward(self, x):
        """
        Args:
            x: Input embeddings of shape (batch_size, seq_len, d_model)

        Returns:
            Embeddings with positional information added
        """
        # Slice positional encodings to match sequence length
        # Broadcasting adds the same positional encoding to each batch element
        return x + self.pe[:, :x.size(1)]

Importantly, positional embeddings depend only on the token’s position within a sequence, not on which sequence in the batch it belongs to. Therefore, the same positional embedding for position t is shared across all sequences in a batch and is broadcast during computation. This ensures that positional information encodes relative and absolute order, rather than introducing spurious batch-specific signals.

There are two types of positional embeddings:

  • Fixed (sinusoidal), where positions are encoded using sine and cosine functions of different frequencies, enabling the model to generalize to longer sequences and infer relative distances between tokens.

  • Learned, where positional vectors are trained parameters similar to word embeddings.

In both cases, positional embeddings are combined with token embeddings via element-wise addition, preserving dimensionality while enriching each token representation with sequence order information. This simple design enables Transformers to model syntax, dependencies, and context over sequences effectively, despite their parallel processing nature.
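A quick sanity check of the PositionalEncoding module above (random embeddings and arbitrary sizes, for illustration only): the positional signal is added without changing the tensor's shape, and the same offset is applied to every sequence in the batch.

# Sanity check: arbitrary sizes, random stand-ins for token embeddings
d_model, seq_len, batch_size = 16, 10, 2

pos_enc = PositionalEncoding(d_model, max_len=50)
tok_emb = torch.randn(batch_size, seq_len, d_model)   # stand-in for token embeddings

out = pos_enc(tok_emb)
print(out.shape)  # torch.Size([2, 10, 16]) -- dimensionality is preserved

# The offset added at each position is identical for every batch element
print(torch.allclose(out[0] - tok_emb[0], out[1] - tok_emb[1]))  # True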

Scaled Dot-Product Attention¶

Scaled dot-product attention is the core mechanism that allows Transformer models to selectively focus on the most relevant parts of a sequence when processing each token. Instead of treating all tokens equally, attention computes how strongly one token should “attend” to every other token based on learned representations.

Each token is represented using three vectors:

  • Query (Q): what this token is looking for
  • Key (K): what each token offers
  • Value (V): the information each token contains

Attention works by comparing each query with all keys using a dot product, producing a set of similarity scores that measure relevance. A higher score means the key token is more relevant to the query token. These scores are then normalized into a probability distribution and used to compute a weighted sum of the value vectors, producing a context-aware representation of each token.

class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_k):
        super().__init__()

        # Scaling factor = sqrt(d_k)
        #
        # The dot product Q · K grows in magnitude with the dimensionality d_k.
        # Without scaling, large dot products would push the softmax into
        # extremely peaked distributions, causing:
        #   - Vanishing gradients
        #   - Overconfident, brittle attention weights
        #
        # Dividing by sqrt(d_k) keeps the variance of the scores roughly constant,
        # stabilizing training and ensuring smoother attention distributions.
        self.scale = math.sqrt(d_k)

    def forward(self, Q, K, V, mask=None):
        """
        Args:
            Q: Queries of shape (..., seq_len_q, d_k)
            K: Keys of shape (..., seq_len_k, d_k)
            V: Values of shape (..., seq_len_k, d_v)
            mask: Optional attention mask (broadcastable to scores shape)
                  Used to block padding tokens or future tokens (causal masking)

        Returns:
            output: Attention-weighted values
            attn: Attention weight matrix
        """

        # Compute raw attention scores via dot product:
        #   score_{i,j} = Q_i · K_j
        #
        # K is transposed so that:
        #   (..., seq_len_q, d_k) @ (..., d_k, seq_len_k)
        # → (..., seq_len_q, seq_len_k)
        #
        # Each score measures how much token i should attend to token j.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        # Apply mask (if provided) to prevent attention to certain positions:
        #   - Padding tokens (in encoder attention)
        #   - Future tokens (in decoder self-attention)
        #
        # Masked positions are set to a large negative value so that
        # softmax assigns them near-zero probability.
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        # Normalize scores into a probability distribution using softmax.
        # For each query token, attention weights over all key positions sum to 1.
        attn = torch.softmax(scores, dim=-1)

        # Compute the final attention output as a weighted sum of values:
        #   output_i = Σ_j attn_{i,j} · V_j
        #
        # This mixes information from all tokens, weighted by relevance.
        return torch.matmul(attn, V), attn

A critical design choice is the scaling factor. As the dimensionality of the key vectors increases, the magnitude of dot products grows, which would cause the softmax function to become overly peaked. This would make the model excessively confident in a small number of tokens and lead to vanishing gradients during training. Dividing by sqrt(d_k) stabilizes the distribution of attention scores, ensuring smooth gradients and more balanced attention across tokens.

Another key design element is masking, which allows the model to control where attention is permitted. Masks prevent attention to padding tokens (which carry no information) and, in autoregressive settings, prevent tokens from attending to future positions. This enforces the correct information flow while keeping the attention mechanism fully parallelizable.

By combining similarity-based scoring, variance-stabilizing scaling, and flexible masking, scaled dot-product attention enables Transformers to model long-range dependencies, contextual meaning, and structured information flow efficiently. Importantly, this mechanism is fully differentiable, parallelizable, and independent of sequence length assumptions, making it a foundational building block of modern sequence models.
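As a quick shape check of the module above (random tensors, arbitrary sizes chosen only for illustration):

# Toy shape check for ScaledDotProductAttention
attn_layer = ScaledDotProductAttention(d_k=8)

Q = torch.randn(2, 5, 8)    # (batch, seq_len_q, d_k)
K = torch.randn(2, 7, 8)    # (batch, seq_len_k, d_k)
V = torch.randn(2, 7, 16)   # (batch, seq_len_k, d_v)

out, weights = attn_layer(Q, K, V)
print(out.shape)            # torch.Size([2, 5, 16])
print(weights.shape)        # torch.Size([2, 5, 7])
print(weights.sum(dim=-1))  # every row sums to 1 (softmax over keys)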

MultiHeadAttention¶

Multi-head attention extends scaled dot-product attention by allowing the model to attend to different types of relationships simultaneously. Instead of computing a single attention distribution over the entire embedding space, the model splits the representation into multiple lower-dimensional subspaces called heads and performs attention independently in each one.

Each head learns to focus on a different aspect of the sequence. For example, one head may specialize in short-range syntactic relationships (such as subject–verb agreement), while another may focus on long-range semantic dependencies or positional structure. This decomposition allows the model to capture richer and more diverse patterns than a single attention mechanism could.

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()

        # d_model must be divisible by the number of heads
        # because we split the embedding space evenly across heads.
        # Each head operates on a lower-dimensional subspace.
        assert d_model % n_heads == 0

        # Dimensionality per head (d_k)
        # Total dimension is preserved when heads are concatenated back.
        self.d_k = d_model // n_heads
        self.n_heads = n_heads

        # Linear projections for Queries, Keys, and Values
        #
        # These allow the model to learn different representations
        # of the same input depending on how it is being used:
        #   - As a query (what am I looking for?)
        #   - As a key   (what do I offer?)
        #   - As a value (what information do I carry?)
        #
        # Each projects from d_model → d_model, after which we split
        # into multiple heads.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        # Output projection
        #
        # After attention is computed independently in each head,
        # their outputs are concatenated and mixed using this layer.
        # This allows information from different heads to interact.
        self.W_o = nn.Linear(d_model, d_model)

        # Scaled dot-product attention applied independently per head
        self.attn = ScaledDotProductAttention(self.d_k)

    def forward(self, Q, K, V, mask=None):
        """
        Args:
            Q: Query tensor of shape (batch_size, seq_len_q, d_model)
            K, V: Key/value tensors of shape (batch_size, seq_len_k, d_model)
                  (seq_len_k may differ from seq_len_q, e.g. in cross-attention)
            mask: Optional attention mask

        Returns:
            Output tensor of shape (batch_size, seq_len, d_model)
        """

        B, T_q, _ = Q.shape
        T_k = K.size(1)  # key/value length may differ from T_q (cross-attention)

        # Project inputs into Q, K, V spaces
        # Shape after projection: (B, T, d_model)
        #
        # Then reshape to split into multiple heads:
        #   (B, T, n_heads, d_k)
        #
        # Finally transpose so that heads become a separate dimension:
        #   (B, n_heads, T, d_k)
        #
        # This allows attention to be computed independently per head.
        Q = self.W_q(Q).view(B, T_q, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(B, T_k, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(B, T_k, self.n_heads, self.d_k).transpose(1, 2)

        # Apply scaled dot-product attention to each head in parallel
        # Output shape: (B, n_heads, T_q, d_k)
        x, _ = self.attn(Q, K, V, mask)

        # Recombine heads:
        #   (B, n_heads, T_q, d_k) → (B, T_q, n_heads * d_k)
        #
        # .contiguous() ensures memory layout is correct before reshaping
        x = x.transpose(1, 2).contiguous().view(B, T_q, -1)

        # Final linear projection mixes information across heads
        return self.W_o(x)

To enable this, the input embeddings are first projected into separate query, key, and value spaces using learned linear transformations. These projections are then reshaped so that each head operates on a reduced dimensionality. Attention is computed independently within each head using scaled dot-product attention, producing multiple parallel context representations.

After attention is applied, the outputs of all heads are concatenated and passed through a final linear projection. This step is crucial: it allows information from different heads to be combined and re-integrated into a single representation, enabling interaction between the different relational perspectives learned by each head.

A key design choice is that splitting into multiple heads does not increase computational cost relative to a single large attention operation. The total dimensionality remains constant; it is simply redistributed across heads. This makes multi-head attention both expressive and efficient.

Overall, multi-head attention gives the Transformer the ability to look at the same sequence through multiple lenses at once, making it far more powerful at modeling complex structure, long-range dependencies, and contextual meaning.
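For self-attention, the same tensor is passed as Q, K, and V. A minimal usage sketch with arbitrary sizes:

# Self-attention usage sketch (arbitrary sizes)
mha = MultiHeadAttention(d_model=64, n_heads=8)

x = torch.randn(2, 10, 64)  # (batch, seq_len, d_model)
out = mha(x, x, x)          # Q = K = V = x  ->  self-attention
print(out.shape)            # torch.Size([2, 10, 64])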

FeedForward Network (Position-wise FFN)¶

1. Why do we need a FeedForward network at all?¶

Self-attention mixes information across tokens, but it is largely a linear operation with respect to feature dimensions. Without an additional non-linear component, stacking attention layers would still result in a model that is limited in its ability to learn complex transformations.

This motivation is stated explicitly in Vaswani et al., 2017 (Attention Is All You Need), where each attention sub-layer is followed by a position-wise fully connected feed-forward network to introduce non-linearity and increase representational power.

The FeedForward layer:

  • Operates independently on each token (position-wise)
  • Mixes information across features, not across time
  • Acts like a learned feature transformer at each position

You can think of attention as answering “where should I look?” and the feedforward network as answering “how should I process what I found?”

This division of labor—global interaction via attention, local transformation via FFN—is a core architectural principle of Transformers.

Reference: Vaswani et al., Attention Is All You Need, NeurIPS 2017

2. Why expand to d_ff and then project back?¶

This is a deliberate capacity expansion strategy, introduced in the original Transformer.

In Vaswani et al. (2017), the feedforward network is defined as:

$$\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2$$

where $W_1 \in \mathbb{R}^{d_{\text{model}} \times d_{\text{ff}}}$ and $W_2 \in \mathbb{R}^{d_{\text{ff}} \times d_{\text{model}}}$.
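With the original hyperparameters (d_model = 512, d_ff = 2048), each 512-dimensional token vector is expanded to 2048 dimensions, passed through a ReLU, and projected back down to 512.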

The reasoning:

  • Expanding to a higher-dimensional space (d_ff ≫ d_model) allows the model to represent richer intermediate features.
  • Projecting back to d_model keeps the interface between layers consistent, enabling deep stacking.

This mirrors a well-established pattern in deep learning:

temporary expansion → nonlinear processing → compression

Similar ideas appear in:

  • Bottleneck layers in CNNs (He et al., ResNet)
  • Inverted residuals in MobileNet
  • MLP blocks in Vision Transformers (Dosovitskiy et al., 2020)

References: Vaswani et al., 2017; Dosovitskiy et al., An Image is Worth 16×16 Words, ICLR 2021

3. Why residual connections everywhere?¶

Residual connections were introduced to address the optimization difficulty of deep networks (He et al., 2016) and were adopted wholesale in Transformers.

In the Transformer, every sub-layer (attention and feedforward) is wrapped with a residual connection:

$$x \rightarrow x + \text{sublayer}(x)$$

This ensures that:

  • The model can easily learn identity mappings
  • Gradients flow effectively through many stacked layers
  • Deeper models remain trainable and stable

Vaswani et al. explicitly note that residual connections are critical for training deep attention-based architectures.

Formally, residuals allow layers to learn incremental refinements rather than complete transformations, which significantly improves optimization.

References: He et al., Deep Residual Learning for Image Recognition, CVPR 2016; Vaswani et al., 2017

4. Why LayerNorm after each sub-layer?¶

Layer Normalization was chosen because:

  • It normalizes per token, not across the batch
  • It works well with variable sequence lengths
  • It is stable under parallel computation

In the original Transformer, LayerNorm is applied after the residual connection (post-norm):

x = LayerNorm(x + sublayer(x))

LayerNorm helps:

  • Reduce internal covariate shift
  • Prevent exploding or vanishing activations
  • Stabilize training across layers

Later work (e.g., Pre-LN Transformers) showed that moving LayerNorm before the sub-layer improves gradient flow in very deep models, but the core motivation remains unchanged.

References: Ba et al., Layer Normalization, arXiv 2016; Vaswani et al., 2017; Xiong et al., On Layer Normalization in Transformers, ICML 2020
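To make the post-norm vs. pre-norm distinction concrete, here is a minimal sketch of the two sub-layer wrappers (sublayer is a placeholder for either attention or the FFN; this is illustration only, not part of the model code below):

# Illustration only: `sublayer` stands in for attention or the feedforward network
def post_norm_sublayer(x, sublayer, norm, dropout):
    # Original Transformer (post-LN): normalize AFTER the residual addition
    return norm(x + dropout(sublayer(x)))

def pre_norm_sublayer(x, sublayer, norm, dropout):
    # Pre-LN variant (used in the GPT-style blocks later in this primer):
    # normalize BEFORE the sublayer, keep the residual path as a clean identity
    return x + dropout(sublayer(norm(x)))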

5. Why split attention and feedforward into two sub-layers?¶

This separation is a deliberate inductive bias introduced in the original Transformer architecture.

Each Encoder block alternates between:

  1. Context aggregation (self-attention)
  2. Feature transformation (position-wise feedforward)

This enforces a clean conceptual separation:

  • Attention decides which tokens influence each other
  • Feedforward decides how each token representation should be transformed

Stacking these blocks allows:

  • Global information to propagate through attention
  • Local, non-linear refinement at each layer

This alternating structure is one of the key reasons Transformers scale so well with depth and data.

Reference: Vaswani et al., 2017

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()

        # Position-wise FeedForward Network
        #
        # This is applied independently to each token position.
        # It does NOT mix information across time (sequence length),
        # only across feature dimensions.
        #
        # The structure expands the representation into a higher-
        # dimensional space (d_ff), applies a non-linearity,
        # and then projects it back to d_model.
        #
        # This allows the model to:
        #   - Introduce non-linear transformations
        #   - Increase representational capacity
        #   - Learn complex feature interactions per token
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),  # Expansion (feature mixing)
            nn.ReLU(),                 # Non-linearity
            nn.Linear(d_ff, d_model)   # Compression back to model dimension
        )

    def forward(self, x):
        # x shape: (batch_size, seq_len, d_model)
        # FeedForward is applied identically to every token position
        return self.net(x)


class EncoderBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()

        # Multi-head self-attention:
        # Allows each token to attend to all other tokens in the sequence
        # and build context-aware representations.
        self.attn = MultiHeadAttention(d_model, n_heads)

        # Position-wise feedforward network:
        # Adds non-linear transformation capacity after attention.
        self.ff = FeedForward(d_model, d_ff)

        # Layer normalization layers
        #
        # These stabilize training by normalizing feature distributions
        # and help gradient flow in deep transformer stacks.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Dropout for regularization
        # Prevents overfitting and co-adaptation of neurons.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """
        Args:
            x: Input tensor of shape (batch_size, seq_len, d_model)
            mask: Optional attention mask

        Returns:
            Output tensor of shape (batch_size, seq_len, d_model)
        """

        # --- Self-Attention Sub-layer ---
        #
        # Residual connection:
        #   Preserves the original representation and improves gradient flow.
        #
        # Dropout:
        #   Regularizes attention outputs.
        #
        # LayerNorm (post-norm formulation here):
        #   Stabilizes activations and training dynamics.
        x = self.norm1(
            x + self.dropout(self.attn(x, x, x, mask))
        )

        # --- FeedForward Sub-layer ---
        #
        # Again, we apply:
        #   - Position-wise feedforward transformation
        #   - Dropout for regularization
        #   - Residual connection to preserve information
        #   - LayerNorm for stability
        x = self.norm2(
            x + self.dropout(self.ff(x))
        )

        return x

The Transformer encoder block combines ideas from residual learning, normalization, and MLP expansion into a modular design where attention handles interaction and feedforward layers handle transformation. This structure, first formalized in Attention Is All You Need, remains largely unchanged across modern large language models.
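A quick shape check of the encoder block (arbitrary sizes, random inputs):

# Shape check for EncoderBlock (arbitrary sizes)
enc_block = EncoderBlock(d_model=64, n_heads=8, d_ff=256)

x = torch.randn(2, 10, 64)  # (batch, src_len, d_model)
out = enc_block(x)          # no mask: every token may attend to every other token
print(out.shape)            # torch.Size([2, 10, 64])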

Decoder Block — Why it exists and how it works¶

1. Why does the decoder have masked self-attention?¶

Unlike the encoder, the decoder is used for autoregressive generation. When predicting token t, the model must not access tokens t+1, t+2, ….

This constraint is enforced using a causal (look-ahead) mask in self-attention, as introduced in Vaswani et al., 2017. The mask ensures that each position can only attend to earlier positions, preserving the correct probabilistic factorization.

Without masked self-attention, the decoder would leak future information and collapse training.

2. Why do we need cross-attention at all?¶

Cross-attention is what allows the decoder to condition on the source sequence.

  • Queries come from the decoder (current generation state)
  • Keys and values come from the encoder (encoded source information)

This mechanism allows the decoder to dynamically decide:

“Which parts of the input sequence are relevant right now?”

This design replaces fixed alignment mechanisms used in earlier sequence-to-sequence models (e.g., RNN encoder–decoders) with a fully differentiable, parallelizable alternative.

References: Bahdanau et al., Neural Machine Translation by Jointly Learning to Align and Translate, ICLR 2015; Vaswani et al., 2017

3. Why is self-attention applied before cross-attention?¶

This ordering is intentional and appears in the original Transformer.

The decoder first:

  1. Builds a representation of what has been generated so far (self-attention)
  2. Then aligns that representation with the source sequence (cross-attention)

Conceptually:

  • Self-attention answers: “What have I already said?”
  • Cross-attention answers: “What does the input tell me to say next?”

Reversing this order degrades performance, as the decoder would try to attend to the source without a coherent internal state.

4. Why does the decoder also need a FeedForward network?¶

Just like in the encoder, attention alone is insufficient.

  • Attention mixes information across tokens
  • FeedForward networks introduce non-linearity and feature transformation

In the decoder, the FFN refines the token representation after both:

  • past context (self-attention), and
  • source context (cross-attention)

This ensures expressive power at each decoding step.

5. Why three residual + normalization blocks?¶

Each sub-layer (self-attn, cross-attn, FFN) performs a fundamentally different operation and has different activation statistics. Wrapping each with its own residual connection and LayerNorm:

  • Stabilizes training
  • Preserves information flow
  • Allows deep stacking of decoder blocks

This mirrors the encoder design but extends it to handle conditional generation.

class DecoderBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()

        # Masked multi-head self-attention
        #
        # This allows each target token to attend only to:
        #   - itself
        #   - previous target tokens
        #
        # The causal (look-ahead) mask ensures autoregressive behavior:
        # the model cannot "see the future" during generation.
        self.self_attn = MultiHeadAttention(d_model, n_heads)

        # Cross-attention (encoder–decoder attention)
        #
        # Queries come from the decoder (what do I need?),
        # Keys and Values come from the encoder (what information is available?).
        #
        # This is how the decoder conditions its predictions on the source sequence.
        self.cross_attn = MultiHeadAttention(d_model, n_heads)

        # Position-wise feedforward network
        #
        # Applies non-linear transformation independently to each target token
        # after contextual information has been integrated.
        self.ff = FeedForward(d_model, d_ff)

        # LayerNorm layers for each sub-layer
        #
        # Separate norms are used because each sub-layer has
        # different statistical properties.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        # Dropout for regularization
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, tgt_mask=None, src_mask=None):
        """
        Args:
            x: Decoder input embeddings (batch_size, tgt_len, d_model)
            enc_out: Encoder outputs (batch_size, src_len, d_model)
            tgt_mask: Causal mask to block future target tokens
            src_mask: Mask to block padding in the source sequence

        Returns:
            Decoder output representations
        """

        # Allows each target token to attend only to earlier target tokens.
        # This enforces the autoregressive property required for generation.
        #
        # Residual connection + dropout + LayerNorm stabilize training
        # and preserve the original token representation.
        x = self.norm1(
            x + self.dropout(self.self_attn(x, x, x, tgt_mask))
        )

        # Decoder queries attend over encoder keys and values.
        # This is the information bottleneck where the decoder
        # decides which parts of the source sequence are relevant
        # for predicting the next token.
        x = self.norm2(
            x + self.dropout(self.cross_attn(x, enc_out, enc_out, src_mask))
        )

        # Applies a non-linear, position-wise transformation
        # to refine the decoder representations after all
        # contextual information has been integrated.
        x = self.norm3(
            x + self.dropout(self.ff(x))
        )

        return x

The decoder first reasons about what it has already generated, then selectively consults the source sequence, and finally refines the result through non-linear transformation — all while strictly preventing future information leakage.
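A quick shape check of the decoder block against dummy encoder outputs (arbitrary sizes; the causal mask is built inline here with torch.tril, the same construction used by the causal_mask helper defined later):

# Shape check for DecoderBlock (arbitrary sizes)
dec_block = DecoderBlock(d_model=64, n_heads=8, d_ff=256)

tgt = torch.randn(2, 6, 64)      # decoder input   (batch, tgt_len, d_model)
enc_out = torch.randn(2, 9, 64)  # encoder output  (batch, src_len, d_model)

# Causal mask over the target: position i may only attend to positions <= i
tgt_mask = torch.tril(torch.ones(6, 6)).unsqueeze(0).unsqueeze(1)  # (1, 1, 6, 6)

out = dec_block(tgt, enc_out, tgt_mask=tgt_mask)
print(out.shape)                 # torch.Size([2, 6, 64])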

Transformer — Why this architecture works¶

1. Why separate encoder and decoder?¶

The Transformer is designed for sequence-to-sequence tasks (e.g., translation).

  • The encoder reads and understands the entire input sequence.

  • The decoder generates the output sequence one token at a time, conditioned on:

    • what it has already generated, and
    • the encoder’s representation of the input.

This separation cleanly mirrors the probabilistic goal:

Generate an output sequence conditioned on an input sequence.


2. Why stack multiple layers?¶

A single attention layer can only perform one round of information exchange.

By stacking layers:

  • Early layers capture local or shallow relationships
  • Later layers build more abstract, global representations

This hierarchical refinement is analogous to depth in CNNs and deep MLPs.


3. Why positional encoding at the input?¶

Attention alone has no notion of order. Without positional encodings:

["dog", "bites", "man"]
["man", "bites", "dog"]

would look identical to the model.

Adding positional encodings at the input ensures that order information flows through every layer of the network.


4. Why does the encoder run fully in parallel?¶

Unlike RNNs:

  • There is no recurrence
  • No dependency between time steps inside the encoder

This allows:

  • Massive parallelism on GPUs/TPUs
  • Faster training
  • Better scaling to long sequences

This was one of the key motivations behind Attention Is All You Need.

5. Why project back to vocabulary with fc_out?¶

The model internally operates in a continuous vector space (d_model), but the final task is classification over discrete tokens.

The final linear layer:

  • Converts hidden states into logits
  • One logit per vocabulary token
  • Enables training with cross-entropy loss

class Transformer(nn.Module):
    def __init__(
        self,
        src_vocab,
        tgt_vocab,
        d_model=512,
        n_heads=8,
        d_ff=2048,
        n_layers=6,
        max_len=512
    ):
        super().__init__()

        # Source and target token embeddings
        #
        # These map discrete token IDs to continuous vectors of size d_model.
        # Separate embeddings are used because source and target vocabularies
        # may differ and play different roles in the model.
        self.src_embed = nn.Embedding(src_vocab, d_model)
        self.tgt_embed = nn.Embedding(tgt_vocab, d_model)

        # Positional encoding
        #
        # Adds explicit sequence order information to token embeddings.
        # This is required because attention itself is permutation-invariant.
        self.pos = PositionalEncoding(d_model, max_len)

        # Encoder stack
        #
        # A stack of identical EncoderBlocks.
        # Each block refines representations by:
        #   1) allowing tokens to attend to each other (self-attention)
        #   2) applying non-linear, position-wise transformations (FFN)
        self.encoder = nn.ModuleList([
            EncoderBlock(d_model, n_heads, d_ff)
            for _ in range(n_layers)
        ])

        # Decoder stack
        #
        # Each DecoderBlock:
        #   1) builds an autoregressive representation of the target prefix
        #   2) attends to the encoder outputs (cross-attention)
        #   3) refines representations with a feedforward network
        self.decoder = nn.ModuleList([
            DecoderBlock(d_model, n_heads, d_ff)
            for _ in range(n_layers)
        ])

        # Final linear projection
        #
        # Maps decoder hidden states back to the target vocabulary space,
        # producing logits for next-token prediction.
        self.fc_out = nn.Linear(d_model, tgt_vocab)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """
        Args:
            src: Source token indices (batch_size, src_len)
            tgt: Target token indices (batch_size, tgt_len)
            src_mask: Mask for source padding tokens
            tgt_mask: Causal mask for target tokens

        Returns:
            Logits over target vocabulary for each target position
        """

        # --- Embedding + Positional Encoding ---
        #
        # Token embeddings provide semantic meaning.
        # Positional encodings inject order information.
        #
        # Resulting shape: (batch_size, seq_len, d_model)
        src = self.pos(self.src_embed(src))
        tgt = self.pos(self.tgt_embed(tgt))

        # --- Encoder ---
        #
        # The encoder processes the entire source sequence in parallel.
        # Its output is a contextual representation of the source,
        # which will be attended to by the decoder.
        for layer in self.encoder:
            src = layer(src, src_mask)

        # --- Decoder ---
        #
        # The decoder processes the target sequence autoregressively.
        # At each layer, it:
        #   - attends to past target tokens (masked self-attention)
        #   - attends to encoder outputs (cross-attention)
        for layer in self.decoder:
            tgt = layer(tgt, src, tgt_mask, src_mask)

        # --- Output Projection ---
        #
        # Convert decoder representations into vocabulary logits.
        # These logits are typically passed to a softmax during training
        # or used for greedy/beam search during inference.
        return self.fc_out(tgt)

The Transformer first builds a deep, contextual understanding of the input sequence, then generates the output sequence step by step—each time consulting both its past outputs and the encoded input—using only attention and feedforward layers.

Causal Mask¶

1. What problem does the causal mask solve?¶

In autoregressive decoding, the model must obey:

When predicting token t, it can only use tokens 0 … t−1.

Without a causal mask, self-attention would allow every token to see the entire sequence, including future tokens, which would:

  • Break the probabilistic factorization of sequence generation
  • Cause information leakage during training
  • Make inference behavior inconsistent with training

2. Why a lower-triangular matrix?¶

Consider a sequence of length 4:

tokens: y0 y1 y2 y3

The causal mask looks like:

1 0 0 0
1 1 0 0
1 1 1 0
1 1 1 1

Row i corresponds to query position i. Column j corresponds to key position j.

A 1 means “attention allowed”, a 0 means “blocked”.

This structure enforces:

  • Past → allowed
  • Present → allowed
  • Future → blocked

3. Why not just use zeros and ones directly?¶

The mask itself is not applied directly to the output. Instead, it is used inside attention as:

scores = scores.masked_fill(mask == 0, -∞)

This ensures that:

  • Softmax assigns zero probability to future positions
  • Gradients do not flow through masked connections

4. Why add two extra dimensions?¶

Attention scores have shape:

(batch_size, n_heads, query_len, key_len)

By returning a mask of shape:

(1, 1, size, size)

we allow PyTorch to broadcast the same causal structure across:

  • all batches
  • all attention heads

This ensures that causality is enforced globally and consistently.

5. Why is this essential even during training?¶

Even when the full target sequence is known (teacher forcing):

  • The model must behave as if the future is unknown
  • Otherwise, it learns shortcuts that do not exist at inference time

This alignment between training and inference is critical for stable generation.

def causal_mask(size):
    # Create a square matrix of ones with shape (size, size)
    # This represents all possible query–key positions.
    mask = torch.ones(size, size)

    # Keep only the lower-triangular part of the matrix.
    #
    # Positions above the diagonal correspond to "future" tokens
    # (i.e., positions j > i when predicting token i).
    #
    # torch.tril enforces the causal constraint:
    #   token i can attend to tokens {0, ..., i}
    #   but NOT to tokens {i+1, ..., size-1}
    mask = torch.tril(mask)

    # Add two singleton dimensions so the mask can be broadcast
    # across:
    #   - batch dimension
    #   - attention heads
    #
    # Final shape: (1, 1, size, size)
    # This matches attention score tensors of shape:
    #   (batch_size, n_heads, seq_len, seq_len)
    return mask.unsqueeze(0).unsqueeze(1)

The causal mask enforces the rule “no peeking into the future” by blocking attention to tokens that haven’t been generated yet.

model = Transformer(
    src_vocab=10000,
    tgt_vocab=10000,
    d_model=512,
    n_heads=8,
    n_layers=6
)

src = torch.randint(0, 10000, (2, 20))
tgt = torch.randint(0, 10000, (2, 20))

mask = causal_mask(tgt.size(1))
out = model(src, tgt, tgt_mask=mask)

print(out.shape)  # (batch, seq_len, vocab)
out

So... How Do We Go About Generating Text?¶

Transformers do not generate text; they generate next-token probability distributions. You decide how to sample from them.
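As a concrete (and deliberately simple) example, here is a minimal sampling loop. It is a sketch only: it assumes a decoder-only model whose forward(ids, mask) returns logits of shape (batch, seq_len, vocab_size), a tokenizer with encode/decode methods (one is built later in this primer), and the causal_mask helper defined above.

@torch.no_grad()
def sample(model, tokenizer, prompt, max_new_tokens=50, temperature=1.0):
    # Sketch: assumes model(ids, mask) -> (1, T, vocab_size) logits
    model.eval()
    ids = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long)

    for _ in range(max_new_tokens):
        mask = causal_mask(ids.size(1))
        logits = model(ids, mask)                 # (1, T, vocab)
        logits = logits[:, -1, :] / temperature   # only the last position matters
        probs = torch.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)  # sample one token
        ids = torch.cat([ids, next_id], dim=1)

    return tokenizer.decode(ids[0].tolist())

Greedy decoding would replace the multinomial sampling with probs.argmax(dim=-1, keepdim=True); temperature, top-k, and nucleus sampling are all just different ways of turning the same next-token distribution into a choice.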

Decoder-Only vs Encoder-Decoder¶

Some LLMs are encoder–decoder, but most modern general-purpose LLMs are decoder-only. The split is very clean once you classify them by what probability they model.

You should always start by asking: what probability distribution am I trying to learn?

For TinyStories, the goal is to model P(x₁, x₂, …, x_T), which factorizes as P(x₁) · P(x₂ | x₁) · P(x₃ | x₁, x₂) · …

This is exactly an autoregressive process: each token depends only on the tokens before it.

A decoder-only transformer is built precisely to model P(x_t | x₁ … x_{t−1}) using masked self-attention, which prevents the model from seeing future tokens. The architecture directly enforces the correct causal structure.


An encoder–decoder transformer models a different distribution: P(y | x)

That only makes sense when there is a separate input sequence x, such as:

  • translation (French given English)
  • summarization (summary given document)
  • QA (answer given question)

TinyStories has no separate input. There is nothing to encode. Adding an encoder would give the model access to the entire story during training, which breaks the causal assumption and creates a mismatch between training and generation.

Decoder-only models also guarantee training–inference alignment:

  • during training, the model only sees past tokens
  • during generation, the model only sees past tokens

Encoder–decoder models cannot guarantee this alignment for language modeling without artificial constraints.

So the reason is not “because GPT does it”.

The real reason is:

Architecture follows probability factorization.

If you want to model P(x), use a decoder-only transformer. If you want to model P(y|x), use an encoder–decoder transformer.

Configuration¶

! curl "https://cas-bridge.xethub.hf.co/xet-bridge-us/645e8da96320b0efe40ade7a/02e40cc51c59a4bc6c51bd7bc9acda4316e208745be060558eaf500cd14e9f96?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Content-Sha256=UNSIGNED-PAYLOAD&X-Amz-Credential=cas%2F20260112%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20260112T124634Z&X-Amz-Expires=3600&X-Amz-Signature=a948d01680e37e90762ec67ccadf0d597a9a94e89dd28da762376ed10ed60b41&X-Amz-SignedHeaders=host&X-Xet-Cas-Uid=6507eb42423b46492edf979c&response-content-disposition=inline%3B+filename*%3DUTF-8%27%27TinyStoriesV2-GPT4-train.txt%3B+filename%3D%22TinyStoriesV2-GPT4-train.txt%22%3B&response-content-type=text%2Fplain&x-id=GetObject&Expires=1768225594&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc2ODIyNTU5NH19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2FzLWJyaWRnZS54ZXRodWIuaGYuY28veGV0LWJyaWRnZS11cy82NDVlOGRhOTYzMjBiMGVmZTQwYWRlN2EvMDJlNDBjYzUxYzU5YTRiYzZjNTFiZDdiYzlhY2RhNDMxNmUyMDg3NDViZTA2MDU1OGVhZjUwMGNkMTRlOWY5NioifV19&Signature=BPYTyJ2CTFy3Jb5MCHF4i%7EEAndHrAJP8bK6GP7tv6REvFarHH3eAur3dyE6w-7eo7PJGKzzeQDodxSYhwHQE95b2RuywZ5DlxTS%7EelkvlI52suIS6vgxa2bkGq5sW7zD1LAzuP3UEoJ1mniA7vq8WbQ2OPWKy%7ET87Zc8ieGiMZ7KoEOy4OpiUhY3SiU3e%7EI43nHwlcEQEGQ4VpRG5OlQNEnOSwSUhm4UlHIvz6gt3smSOCvgCV7l3MTY0CpPU00YwTp7w0NbIaHsTMuuk8N0XBVeAl%7EdE0o1qs3RSZeUeY2grJbVaxlevb657i0R%7E1uHDZ1%7E-ctEGcXRXpMcLqvlkw__&Key-Pair-Id=K2L8F4GPSG1IFC" -o TinyStoriesV2-GPT4-train.txt
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 2124M  100 2124M    0     0   128M      0  0:00:16  0:00:16 --:--:-- 66.8M
! head -c 100 TinyStoriesV2-GPT4-train.txt
Once upon a time there was a little boy named Ben. Ben loved to explore the world around him. He sa
import torch

DATA_PATH = "TinyStoriesV2-GPT4-train.txt"  # change if needed
MAX_CHARS = 500_000      # limit for quick runs
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BLOCK_SIZE = 512
BATCH_SIZE = 128   # or 256 if it fits
GRAD_ACCUM = 1
D_MODEL    = 256
N_HEADS   = 4
N_LAYERS  = 4
D_FF      = 1024
LR        = 3e-4
EPOCHS    = 80
AMP       = True

Dataset¶

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

LanguageModelDataset (TinyStories)¶

Historical context¶

Early language models (n-gram models, HMMs) modeled text as a probability distribution over sequences using the factorization:

The probability of a sentence is the product of the probability of each word given all previous words.

Neural language models (Bengio et al., 2003) adopted the same formulation, and this objective was later carried over to RNNs, LSTMs, and eventually Transformers.

When Transformer-based language models (Radford et al., 2018) removed recurrence entirely, they still preserved this exact probabilistic structure. The only requirement was a dataset that could generate many examples of:

“Given a sequence of previous tokens, predict the next token.”

This dataset class exists to construct that learning signal from raw tokenized text such as TinyStories.


Why sliding windows?¶

TinyStories is provided as a long stream of tokens, not pre-segmented into independent training samples.

The design choice here is to use a sliding window over the token stream:

  • Each window provides a fixed-length context (block_size)
  • The target is the same window shifted by one token

This idea predates Transformers and was used in:

  • Feedforward neural language models
  • RNN-based language models
  • GPT-style models

It ensures:

  • Efficient use of all tokens
  • Strong local coherence learning
  • Compatibility with causal self-attention

Why predict a sequence instead of a single token?¶

Although the objective is “predict the next token,” the model predicts the next token at every position simultaneously.

This aligns with:

  • Parallel training (a core advantage of Transformers)
  • Causal masking in self-attention
  • Efficient GPU utilization

Each training example teaches the model multiple prediction tasks in one forward pass.

class LanguageModelDataset(Dataset):
    def __init__(self, data, block_size):
        # `data` is a long sequence of token IDs representing TinyStories.
        # Example:
        #   [12, 45, 891, 34, 78, ...]
        #
        # `block_size` defines the length of the context window the model
        # is allowed to condition on when predicting the next token.
        self.data = data
        self.block_size = block_size

    def __len__(self):
        # Each training example requires `block_size + 1` tokens:
        #   - `block_size` tokens for the input
        #   - 1 additional token for the shifted target
        #
        # We subtract `block_size` to avoid indexing past the end
        # of the token sequence.
        return len(self.data) - self.block_size

    def __getitem__(self, idx):
        # Input sequence (context window)
        #
        # This represents what the model is allowed to see.
        # Shape: (block_size,)
        x = self.data[idx : idx + self.block_size]

        # Target sequence (next-token labels)
        #
        # This is the same sequence shifted one position to the right.
        # For each position t, the model learns to predict y[t]
        # given x[0:t].
        y = self.data[idx + 1 : idx + self.block_size + 1]

        return x, y

This dataset turns TinyStories into overlapping context–next-token prediction tasks, enforcing the same left-to-right learning objective that underpins all modern autoregressive language models.
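A tiny sketch of the one-token shift, using a made-up token stream:

# Illustration with a made-up token stream
toy_data = torch.arange(10)                     # tensor([0, 1, 2, ..., 9])
toy_ds = LanguageModelDataset(toy_data, block_size=4)

x, y = toy_ds[0]
print(x)            # tensor([0, 1, 2, 3])  -> context the model is allowed to see
print(y)            # tensor([1, 2, 3, 4])  -> next-token targets, shifted by one
print(len(toy_ds))  # 6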

Tokenizer¶

BPE (Byte Pair Encoding) Tokenization & Data Pipeline (TinyStories)¶

Why BPE tokenization?¶

This code uses Byte Pair Encoding (BPE) tokenization, a subword tokenization approach that balances vocabulary size and sequence length.

Historically:

  • BPE was introduced for text compression (Gage, 1994) and later adapted for NLP (Sennrich et al., 2016)

  • Modern language models (GPT, BERT, etc.) use subword tokenization like BPE to:

    • handle rare and unknown words gracefully
    • reduce vocabulary size compared to word-level tokenization
    • maintain reasonable sequence lengths
  • BPE became popular in NLP through works like:

    • Sennrich et al., 2016 – "Neural Machine Translation of Rare Words with Subword Units"

For TinyStories, BPE tokenization is a good choice because:

  • It handles the full vocabulary efficiently
  • Reduces sequence length compared to character-level tokenization
  • Provides a more realistic setup similar to production language models

Why encode the entire corpus into one long tensor?¶

data = torch.tensor(tokenizer.encode(text), dtype=torch.long)

Instead of treating each story independently, the dataset is represented as a single continuous stream of tokens.

This aligns with:

  • Autoregressive language modeling theory
  • Sliding-window dataset construction
  • GPT-style training pipelines

The model learns:

“Text is a continuous stream where any position can follow any other position.”

This assumption simplifies training and works well in practice.


Why use LanguageModelDataset + DataLoader?¶

This combination:

  • Converts the token stream into overlapping (x, y) training pairs
  • Enables batching, shuffling, and parallel loading
  • Decouples data logic from model logic

Historically, this design emerged as PyTorch standardized:

  • Dataset for data definition
  • DataLoader for efficient iteration

Why shuffle and drop last?¶

shuffle=True
drop_last=True

  • Shuffling prevents the model from seeing data in a fixed order, improving generalization

  • Dropping the last batch ensures consistent batch sizes, which simplifies:

    • tensor shapes
    • attention masking
    • GPU efficiency

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True

class BPETokenizer:
    def __init__(self, text, vocab_size=1000):
        self.vocab_size = vocab_size

        # Start with character-level vocabulary
        chars = sorted(list(set(text)))
        self.vocab = chars.copy()

        # Create initial merges (character pairs)
        self.merges = {}

        # Convert text to list of tokens (initially characters)
        tokens = list(text)

        # Learn BPE merges
        for _ in range(vocab_size - len(chars)):
            # Count frequency of each adjacent pair
            pair_freqs = {}
            for i in range(len(tokens) - 1):
                pair = (tokens[i], tokens[i + 1])
                pair_freqs[pair] = pair_freqs.get(pair, 0) + 1

            if not pair_freqs:
                break

            # Find most frequent pair
            best_pair = max(pair_freqs, key=pair_freqs.get)

            # Create new token for this pair
            new_token = ''.join(best_pair)
            self.merges[best_pair] = new_token
            self.vocab.append(new_token)

            # Merge all occurrences of this pair in the token list
            i = 0
            while i < len(tokens) - 1:
                if (tokens[i], tokens[i + 1]) == best_pair:
                    tokens[i] = new_token
                    del tokens[i + 1]
                else:
                    i += 1

        # Create mappings
        self.stoi = {token: i for i, token in enumerate(self.vocab)}
        self.itos = {i: token for token, i in self.stoi.items()}

    def encode(self, s):
        # Start with character-level encoding
        tokens = list(s)

        # Apply BPE merges greedily
        changed = True
        while changed:
            changed = False
            i = 0
            while i < len(tokens) - 1:
                pair = (tokens[i], tokens[i + 1])
                if pair in self.merges:
                    tokens[i] = self.merges[pair]
                    del tokens[i + 1]
                    changed = True
                else:
                    i += 1

        return [self.stoi[token] for token in tokens]

    def decode(self, ids):
        # Convert IDs back to tokens
        tokens = [self.itos[i] for i in ids]
        return ''.join(tokens)
print("Loading data...")
with open(DATA_PATH, "r", encoding="utf-8") as f:
    text = f.read()[:MAX_CHARS]

tokenizer = BPETokenizer(text, vocab_size=1000)
data = torch.tensor(tokenizer.encode(text), dtype=torch.long)

# ---------------- SPLIT ----------------
split_ratio = 0.9
split_idx = int(len(data) * split_ratio)

train_data = data[:split_idx]
val_data   = data[split_idx:]

# ---------------- DATASETS ----------------
train_dataset = LanguageModelDataset(train_data, BLOCK_SIZE)
val_dataset   = LanguageModelDataset(val_data, BLOCK_SIZE)

# ---------------- LOADERS ----------------
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    num_workers=3,
    shuffle=True,
    drop_last=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    num_workers=2,
    shuffle=False,      # IMPORTANT
    drop_last=True
)
Loading data...
/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py:627: UserWarning: This DataLoader will create 3 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(

This pipeline converts raw TinyStories text into a stream of BPE subword tokens and then into overlapping context–next-token prediction tasks, exactly matching the autoregressive objective used by GPT-style language models.
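A quick round-trip check on the tokenizer that was just built (the exact IDs depend on the learned merges, so the printed values are illustrative only):

# Round-trip check on the BPE tokenizer built above
sample_text = "Once upon a time"
ids = tokenizer.encode(sample_text)

print(ids)                    # token IDs (values depend on the learned merges)
print(tokenizer.decode(ids))  # "Once upon a time" -- lossless round trip
print(len(ids), "tokens for", len(sample_text), "characters")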

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1)
        div = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

GPT¶

MultiAttentionHead¶

Differences vs earlier attention (only real changes)¶

1. Self-attention only (no Q, K, V inputs)¶

Earlier

forward(self, Q, K, V, mask=None)

Now

forward(self, x, mask=None)

What changed

  • Q, K, V are all derived from the same tensor x
  • This is pure self-attention only

Why

  • GPT-style decoder-only models never use cross-attention
  • Simplifies the API and reduces surface area for bugs
  • Matches how GPT actually operates

2. Attention logic is inlined (no ScaledDotProductAttention module)¶

Earlier

self.attn = ScaledDotProductAttention(self.d_k)
x, _ = self.attn(Q, K, V, mask)

Now

scores = (q @ k.transpose(-2, -1)) / sqrt(d_k)
attn = softmax(scores)
out = attn @ v

What changed

  • Attention computation is written inline
  • No separate abstraction

Why

  • GPT-style minimalism
  • Fewer function calls
  • Easier to fuse / optimize later
  • Matches most real GPT codebases

3. Causal masking assumed externally¶

Earlier

  • Masking semantics varied (encoder mask, decoder mask, src/tgt)

Now

scores = scores.masked_fill(mask == 0, -1e9)

What changed

  • This module assumes the mask is already causal and broadcastable
  • No logic to construct or interpret masks internally

Why

  • Cleaner separation of concerns
  • GPT always uses causal masking
  • Mask construction belongs to the model, not attention

4. No attention weights returned¶

Earlier

return output, attn

Now

return self.W_o(out)

What changed

  • Attention weights are discarded

Why

  • GPT training never uses attention maps
  • Saves memory
  • Reduces bandwidth and overhead

5. Single-path projection (GPT style)¶

This version:

  • Projects Q/K/V from the same input
  • Concatenates heads
  • Applies one output projection

This matches GPT-1 → GPT-4 style attention exactly.


What did not change (important)¶

  • Scaled dot-product attention
  • Multi-head splitting
  • Softmax over keys
  • Residual compatibility
  • Parallel attention computation

Only the scope and structure changed, not the math.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    """
    GPT-style multi-head self-attention
    """

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        # Fused projection for Q, K, V (faster + cleaner)
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)

        # Output projection
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        """
        Args:
            x: (B, T, d_model)
            mask: (1 or B, 1 or n_heads, T, T) causal mask (0 = block)

        Returns:
            (B, T, d_model)
        """
        B, T, C = x.shape

        # ------------------------------------------------
        # Project & split into Q, K, V
        # ------------------------------------------------
        # (B, T, 3 * C) → (B, T, 3, n_heads, d_k)
        qkv = self.qkv_proj(x).view(B, T, 3, self.n_heads, self.d_k)

        # (B, n_heads, T, d_k)
        q, k, v = qkv.unbind(dim=2)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # ------------------------------------------------
        # Scaled dot-product attention
        # ------------------------------------------------
        # (B, n_heads, T, T)
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))

        attn = F.softmax(scores, dim=-1)

        # (B, n_heads, T, d_k)
        out = attn @ v

        # ------------------------------------------------
        # Recombine heads
        # ------------------------------------------------
        # (B, T, C)
        out = out.transpose(1, 2).contiguous().view(B, T, C)

        return self.out_proj(out)

This version removes generality (cross-attention, modular abstractions) in favor of a minimal, self-attention–only implementation that matches how GPT actually works in practice.
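
As a quick sanity check of the module above (with made-up sizes, not the notebook's hyperparameters):

import torch

mha = MultiHeadAttention(d_model=64, n_heads=4)
x = torch.randn(2, 10, 64)                                 # (B, T, d_model)
mask = torch.tril(torch.ones(10, 10)).view(1, 1, 10, 10)   # 0 above the diagonal = blocked

out = mha(x, mask)
print(out.shape)  # torch.Size([2, 10, 64])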

1. Decoder-only architecture: no encoder, no cross-attention¶

Earlier Transformer (Vaswani et al., 2017)¶

  • Designed for sequence-to-sequence tasks (e.g., translation)

  • Architecture:

    • Encoder stack
    • Decoder stack
    • Cross-attention between them

GPT (Radford et al., 2018 → GPT-4)¶

  • Designed for language modeling and generation

  • Architecture:

    • Decoder-only Transformer
    • No encoder
    • No cross-attention

Why the change?

Language modeling only requires:

“Predict the next token given all previous tokens.”

There is no source sequence to condition on, so:

  • Encoder becomes unnecessary
  • Cross-attention is removed
  • Model becomes simpler and more scalable

This simplification was one of the key insights behind GPT.


2. Architectural evolution: Pre-LN instead of Post-LN¶

Earlier code (Encoder / Decoder blocks)¶

We used post-norm:

x = LayerNorm(x + sublayer(x))

GPT-style block¶

This implementation uses pre-norm:

x = x + sublayer(LayerNorm(x))

Why this change?

Research showed that:

  • Post-LN Transformers become unstable at depth
  • Gradients struggle to flow through many layers

Pre-LN:

  • Improves gradient flow
  • Allows very deep models (100+ layers)
  • Is now standard in GPT, LLaMA, PaLM, etc.

Key reference: Xiong et al., On Layer Normalization in the Transformer Architecture, ICML 2020
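
In code, the two variants differ only in where LayerNorm sits relative to the residual connection. A minimal sketch, with a generic sublayer standing in for attention or the feedforward network:

import torch.nn as nn


class PostLNBlock(nn.Module):
    """Original Transformer: normalize after adding the residual."""
    def __init__(self, d_model, sublayer):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        return self.norm(x + self.sublayer(x))


class PreLNBlock(nn.Module):
    """GPT-style: normalize the sublayer input, keep the residual path untouched."""
    def __init__(self, d_model, sublayer):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        return x + self.sublayer(self.norm(x))

Because the pre-LN residual path is never normalized, gradients can flow straight through the identity connection, which is what makes very deep stacks trainable.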


3. Single attention type: masked self-attention only¶

Earlier:

  • Encoder self-attention
  • Decoder masked self-attention
  • Decoder cross-attention

GPT:

  • Only masked self-attention

self.attn = MultiHeadAttention(d_model, n_heads)

with:

mask = causal_mask(T)

Why?

  • GPT models a single sequence
  • Causality enforces correct generation
  • Cross-attention is unnecessary

4. Weight tying (embedding ↔ output projection)¶

self.head.weight = self.embed.weight

This technique predates GPT. It was introduced and popularized in:

  • Press & Wolf (2017), Using the Output Embedding to Improve Language Models
  • The original Transformer (which shares its embedding and pre-softmax projection weights)
  • GPT-1 onward

Why weight tying?

  • Reduces parameter count

  • Improves generalization

  • Enforces symmetry between:

    • “reading” a token (embedding)
    • “writing” a token (logits)

This is now standard practice.
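
A rough parameter count shows why tying matters, using hypothetical GPT-2-Small-like sizes (this notebook's vocabulary and d_model are much smaller):

# Hypothetical sizes, for illustration only
vocab_size, d_model = 50_257, 768

untied = 2 * vocab_size * d_model   # separate embedding matrix + output projection
tied = vocab_size * d_model         # one shared matrix

print(f"untied: {untied / 1e6:.1f}M params, tied: {tied / 1e6:.1f}M params")
# untied: 77.2M params, tied: 38.6M params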


5. Final LayerNorm before output¶

self.ln_f = nn.LayerNorm(D_MODEL)

This final normalization layer:

  • Stabilizes logits
  • Improves training dynamics
  • Became standard in GPT-style models

Earlier Transformers normalized inside blocks only.


6. Positional encoding remains (but later evolved)¶

This code still uses sinusoidal positional encoding, but historically:

  • GPT-1 / GPT-2 / GPT-3 → learned positional embeddings
  • GPT-J / GPT-NeoX / LLaMA → RoPE (rotary position embeddings)
  • ALiBi → attention bias instead of embeddings

But the conceptual role remains unchanged:

Inject order into an otherwise permutation-invariant attention mechanism.
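
For comparison, the learned-position variant used by GPT-1 and GPT-2 is just an embedding table indexed by position. A minimal sketch (not the notebook's PositionalEncoding; swapping it in would change nothing else in the model):

import torch
import torch.nn as nn


class LearnedPositionalEmbedding(nn.Module):
    """GPT-1/GPT-2-style learned absolute positions (illustrative sketch)."""
    def __init__(self, max_len: int, d_model: int):
        super().__init__()
        self.pos_emb = nn.Embedding(max_len, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, T, d_model)
        T = x.size(1)
        positions = torch.arange(T, device=x.device)   # (T,)
        return x + self.pos_emb(positions)             # broadcasts over the batch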

class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()

        # Masked multi-head self-attention
        # Uses causal masking to prevent access to future tokens
        self.attn = MultiHeadAttention(d_model, n_heads)

        # Position-wise feedforward network
        self.ff = FeedForward(d_model, d_ff)

        # Pre-layer normalization
        # Improves gradient flow in deep networks
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x, mask):
        # Pre-LN + residual connection for attention
        # x ← x + Attention(LayerNorm(x))
        x = x + self.attn(self.ln1(x), mask)

        # Pre-LN + residual connection for feedforward
        # x ← x + FFN(LayerNorm(x))
        x = x + self.ff(self.ln2(x))

        return x


class GPT(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()

        # Token embedding layer
        # Maps token IDs to continuous vectors
        self.embed = nn.Embedding(vocab_size, D_MODEL)

        # Positional encoding to inject order information
        self.pos = PositionalEncoding(D_MODEL, BLOCK_SIZE)

        # Stack of Transformer blocks
        # Each block refines representations autoregressively
        self.blocks = nn.ModuleList([
            TransformerBlock(D_MODEL, N_HEADS, D_FF)
            for _ in range(N_LAYERS)
        ])

        # Final LayerNorm before output projection
        self.ln_f = nn.LayerNorm(D_MODEL)

        # Output projection to vocabulary space
        self.head = nn.Linear(D_MODEL, vocab_size)

        # Weight tying between input embeddings and output logits
        # Reduces parameters and improves generalization
        self.head.weight = self.embed.weight

    def forward(self, x):
        """
        Args:
            x: Input token IDs (batch_size, seq_len)

        Returns:
            Logits over vocabulary for each position
        """

        B, T = x.shape

        # Create causal mask to enforce autoregressive constraint
        mask = causal_mask(T).to(x.device)

        # Token embeddings + positional encodings
        x = self.embed(x)
        x = self.pos(x)

        # Apply stacked Transformer blocks
        for block in self.blocks:
            x = block(x, mask)

        # Final normalization and projection to logits
        x = self.ln_f(x)
        return self.head(x)

GPT removes the encoder, enforces causality everywhere, stabilizes training with pre-norm, and scales depth—turning the Transformer into a pure next-token prediction machine.

Training¶

model = GPT(tokenizer.vocab_size).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()
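
Before training, it is worth checking the model size; the roughly 3.1M-parameter figure quoted at the end depends on D_MODEL, N_LAYERS, N_HEADS, and the vocabulary size defined earlier:

# Count trainable parameters (shared/tied weights are counted once)
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model size: {n_params / 1e6:.1f}M parameters")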
import torch, gc

# Release cached GPU memory left over from previous runs before training
gc.collect()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

print(torch.cuda.memory_summary())
|===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |  13858 KiB |  13858 KiB |  13858 KiB |      0 B   |
|       from large pool |      0 KiB |      0 KiB |      0 KiB |      0 B   |
|       from small pool |  13858 KiB |  13858 KiB |  13858 KiB |      0 B   |
|---------------------------------------------------------------------------|
| Active memory         |  13858 KiB |  13858 KiB |  13858 KiB |      0 B   |
|       from large pool |      0 KiB |      0 KiB |      0 KiB |      0 B   |
|       from small pool |  13858 KiB |  13858 KiB |  13858 KiB |      0 B   |
|---------------------------------------------------------------------------|
| Requested memory      |  13857 KiB |  13857 KiB |  13857 KiB |      0 B   |
|       from large pool |      0 KiB |      0 KiB |      0 KiB |      0 B   |
|       from small pool |  13857 KiB |  13857 KiB |  13857 KiB |      0 B   |
|---------------------------------------------------------------------------|
| GPU reserved memory   |  14336 KiB |  14336 KiB |  14336 KiB |      0 B   |
|       from large pool |      0 KiB |      0 KiB |      0 KiB |      0 B   |
|       from small pool |  14336 KiB |  14336 KiB |  14336 KiB |      0 B   |
|---------------------------------------------------------------------------|
| Non-releasable memory | 489472 B   |   1816 KiB |   7704 KiB |   7226 KiB |
|       from large pool |      0 B   |      0 KiB |      0 KiB |      0 KiB |
|       from small pool | 489472 B   |   1816 KiB |   7704 KiB |   7226 KiB |
|---------------------------------------------------------------------------|
| Allocations           |      53    |      53    |      53    |       0    |
|       from large pool |       0    |       0    |       0    |       0    |
|       from small pool |      53    |      53    |      53    |       0    |
|---------------------------------------------------------------------------|
| Active allocs         |      53    |      53    |      53    |       0    |
|       from large pool |       0    |       0    |       0    |       0    |
|       from small pool |      53    |      53    |      53    |       0    |
|---------------------------------------------------------------------------|
| GPU reserved segments |       7    |       7    |       7    |       0    |
|       from large pool |       0    |       0    |       0    |       0    |
|       from small pool |       7    |       7    |       7    |       0    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |       2    |       4    |       7    |       5    |
|       from large pool |       0    |       0    |       0    |       0    |
|       from small pool |       2    |       4    |       7    |       5    |
|---------------------------------------------------------------------------|
| Oversize allocations  |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Oversize GPU segments |       0    |       0    |       0    |       0    |
|===========================================================================|

import torch
from torch.amp import autocast, GradScaler

# Compile model (PyTorch 2.x)
model = torch.compile(model)

scaler = GradScaler("cuda")

ACC_STEPS = 4          # 🔧 tune this
CLIP_NORM = 1.0
@torch.no_grad()
def generate(prompt, max_new_tokens=300, temperature=0.8):
    model.eval()
    ids = torch.tensor([tokenizer.encode(prompt)], device=DEVICE)

    for _ in range(max_new_tokens):
        # Truncate input to BLOCK_SIZE if it exceeds it.
        # The model was trained with BLOCK_SIZE context.
        input_ids = ids[:, -BLOCK_SIZE:]
        logits = model(input_ids)
        logits = logits[:, -1, :] / temperature
        probs = F.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, 1)
        ids = torch.cat([ids, next_id], dim=1)

    return tokenizer.decode(ids[0].tolist())
from tqdm import tqdm
import torch
import math
import os
from torch.amp import autocast, GradScaler

# ============================================================
# RESUME FROM CHECKPOINT (FULL STATE)
# ============================================================

best_val_loss = math.inf
start_epoch = 0
scaler = GradScaler("cuda")



if os.path.exists("best_model.pt"):
    print("Resuming from best_model.pt")
    ckpt = torch.load("best_model.pt", map_location=DEVICE)

    if isinstance(ckpt, dict) and "model" in ckpt:
        # ✅ NEW-style checkpoint
        model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        scaler.load_state_dict(ckpt["scaler"])
        best_val_loss = ckpt["best_val_loss"]
        start_epoch = ckpt["epoch"] + 1
    else:
        # ✅ OLD-style checkpoint (weights only)
        model.load_state_dict(ckpt)
        print("Loaded weights-only checkpoint (optimizer/scaler reset)")

model.to(DEVICE)

# ============================================================
# TRAINING LOOP
# ============================================================

for epoch in range(start_epoch, EPOCHS):

    # ===================== TRAIN =====================
    model.train()
    optimizer.zero_grad(set_to_none=True)
    train_loss = 0.0

    for step, (x, y) in enumerate(tqdm(train_loader, desc=f"Train {epoch+1}")):
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)

        with autocast("cuda"):
            logits = model(x)
            loss = criterion(
                logits.view(-1, logits.size(-1)),
                y.view(-1)
            )
            loss = loss / ACC_STEPS

        scaler.scale(loss).backward()
        train_loss += loss.item() * ACC_STEPS

        should_step = (
            (step + 1) % ACC_STEPS == 0
            or (step + 1) == len(train_loader)  # flush last batch
        )

        if should_step:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_NORM)

            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

    train_loss /= len(train_loader)

    # ===================== VALIDATE =====================
    model.eval()
    val_loss = 0.0
    MAX_VAL_BATCHES = 200

    with torch.no_grad(), autocast("cuda"):
        for i, (x, y) in enumerate(val_loader):
            if i >= MAX_VAL_BATCHES:
                break

            x = x.to(DEVICE, non_blocking=True)
            y = y.to(DEVICE, non_blocking=True)

            logits = model(x)
            loss = criterion(
                logits.view(-1, logits.size(-1)),
                y.view(-1)
            )
            val_loss += loss.item()

    val_loss /= min(len(val_loader), MAX_VAL_BATCHES)

    # ===================== LOGGING + CHECKPOINT =====================
    improved = val_loss < best_val_loss
    if improved:
        best_val_loss = val_loss
        torch.save(
            {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scaler": scaler.state_dict(),
                "epoch": epoch,
                "best_val_loss": best_val_loss,
            },
            "best_model.pt",
        )

    print(
        f"Epoch {epoch+1:3d} | "
        f"Train {train_loss:.4f} | "
        f"Val {val_loss:.4f}"
        + ("  ✓" if improved else "")
    )

    # ===================== SAMPLE =====================
    if (epoch + 1) % 10 == 0:
        model.eval()
        # Use the standalone generate() helper defined above
        # (the GPT module itself does not define a .generate method)
        sample = generate("Once upon a time", max_new_tokens=100)
        print(f"\nSample:\n{sample}\n")

# ============================================================
# DONE
# ============================================================

print("\n" + "=" * 50)
print(f"Training complete! Best val loss: {best_val_loss:.4f}")
print("=" * 50)
Resuming from best_model.pt
Loaded weights-only checkpoint (optimizer/scaler reset)
Train 1:   0%|          | 0/1539 [00:00<?, ?it/s]/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py:627: UserWarning: This DataLoader will create 3 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(
Train 1: 100%|██████████| 1539/1539 [04:39<00:00,  5.51it/s]
Epoch   1 | Train 2.9840 | Val 2.8299  ✓
Train 2: 100%|██████████| 1539/1539 [04:47<00:00,  5.36it/s]
Epoch   2 | Train 2.0371 | Val 2.9852
Train 3: 100%|██████████| 1539/1539 [04:46<00:00,  5.37it/s]
Epoch   3 | Train 1.2778 | Val 4.0615
Train 4: 100%|██████████| 1539/1539 [04:46<00:00,  5.38it/s]
Epoch   4 | Train 0.7066 | Val 5.6822
Train 5: 100%|██████████| 1539/1539 [04:45<00:00,  5.40it/s]
Epoch   5 | Train 0.3048 | Val 7.1991
Train 6:   5%|▍         | 74/1539 [00:13<04:36,  5.31it/s]

Generation¶

Drumroll ....¶

print("\n--- GENERATED STORY ---\n")
print(generate("Once upon a time  "))
--- GENERATED STORY ---

Once upon a time  was tirered. It made go to the park, but he also lion named sed sed with a new friend, a little girl named Lily him. She saw her frog and opped her head. Lily was very munice. Then, She camemmembered Mom comineace braceful was not too too happy ve fun.
<|endoftext|>
Once upon a time, there was a little girl named Mia. Mia had a big, red in a little house with her  always smiled. Mia truckly little cat the ad. Lily lived in a small house with her momy. Mia loved to stay very happy.
Mia went outside to play in the stawn and saw a lotched the slide. She was very happy and s. They cyedguit. The slizzlep. Mia could not find it was very happy. The kids girl 

A 3.1M-parameter model is very small by modern language-model standards. For context, classic “toy” GPT implementations used for learning typically range from 1–10M parameters, which is enough to capture basic token statistics and short-range patterns but not sustained coherence. GPT-2 Small already has 117M parameters (≈40× larger), GPT-2 Medium 345M, GPT-2 Large 774M, and GPT-2 XL 1.5B. Modern production LLMs operate in the tens to hundreds of billions of parameters. Practically, this means a 3.1M model is expected to produce locally sensible text but struggle with long-term consistency, reasoning, and abstraction.

This is ideal for understanding how transformers work, not for demonstrating strong language capability.