GPT From Scratch - A Simple Primer¶
Hi! I am Srihari Unnikrishnan (@pythongiant), and this is a small primer on transformers, encoder–decoder transformers, and generating text with a decoder-only transformer.
Send me an email at srihari[dot]unnikrishnan[at]gmail[dot]com if you have any questions or any suggestions/recommendations!
I would be more than glad to help you out.
Encoder - Decoder Transformer¶
An Encoder–Decoder Transformer is the original Transformer architecture introduced in Attention Is All You Need (Vaswani et al., 2017). It is designed to model conditional sequence generation.
Why encoder–decoder lost dominance¶
Encoder–decoder models were very strong around 2018–2020. Decoder-only models won out because they:
Scale better (simpler architecture, fewer attention paths)
Have perfect training–inference alignment
Can absorb conditional tasks via prompting
Are easier to use as general-purpose models
Enable efficient KV caching during generation
In practice, decoder-only models learned to simulate encoder–decoder behavior by treating the “input” as part of the prefix.
Architecture¶

Positional Embeddings¶
Positional embeddings add explicit numerical information about the order (position) of elements in a sequence to their vector representations. This is necessary because Transformer models process all tokens in parallel using self-attention and therefore have no inherent notion of sequence order. Without positional information, a Transformer would treat a sentence as a set of tokens rather than an ordered sequence.
To solve this, each position in a sequence (e.g., 1st word, 2nd word, 3rd word, etc.) is assigned a position-specific vector. This positional vector is added to the token’s embedding before being passed into the model. As a result, the final input representation contains both:
What the token is (semantic content from the word embedding), and
Where the token appears (its position in the sequence).
This allows the model to distinguish between sequences with the same words but different orderings, such as “dog bites man” versus “man bites dog”, where meaning critically depends on position.
The example shows fixed, absolute sinusoidal positional embeddings as introduced in Attention Is All You Need.
# Position Encoding
import torch
import torch.nn as nn
import math
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
super().__init__() # Initialize nn.Module internals
# Create a (max_len x d_model) matrix to store positional encodings
pe = torch.zeros(max_len, d_model)
# Create position indices: [0, 1, 2, ..., max_len-1]
# Shape: (max_len, 1)
position = torch.arange(0, max_len).unsqueeze(1)
# Compute the frequency scaling term for each even dimension
# Shape: (d_model / 2,)
#
# Each pair of embedding dimensions (2i, 2i+1) shares a unique frequency.
# These frequencies are exponentially spaced so that:
# - Early dimensions oscillate quickly (high-frequency signals),
# capturing fine-grained, local positional differences.
# - Later dimensions oscillate slowly (low-frequency signals),
# capturing long-range, global positional structure.
#
# This multi-scale design allows the model to represent both short- and
# long-distance relationships between tokens.
#
# The exponential scaling (10000^(-2i / d_model)) ensures:
# - Positional encodings remain unique across long sequences
# - Relative position information can be recovered using linear operations,
# which is critical for attention mechanisms.
#
# Using fixed (non-learned) frequencies avoids overfitting to training
# sequence lengths and enables generalization to longer sequences.
div_term = torch.exp(
torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
)
# Apply sine to even embedding dimensions (0, 2, 4, ...)
# Broadcasting: (max_len, 1) * (d_model/2,) -> (max_len, d_model/2)
pe[:, 0::2] = torch.sin(position * div_term)
# If positional encodings were different per batch element, the model could learn:
# “this sentence is in batch slot 3”
# “batch index correlates with label”
# That would be spurious information and break generalization. Hence we ensure: No batch-specific signal and that only relative token order is encoded
# Apply cosine to odd embedding dimensions (1, 3, 5, ...)
pe[:, 1::2] = torch.cos(position * div_term)
# Add a batch dimension so it can be broadcast across batches
# Final shape: (1, max_len, d_model)
pe = pe.unsqueeze(0)
self.register_buffer("pe", pe)
def forward(self, x):
"""
Args:
x: Input embeddings of shape (batch_size, seq_len, d_model)
Returns:
Embeddings with positional information added
"""
# Slice positional encodings to match sequence length
# Broadcasting adds the same positional encoding to each batch element
return x + self.pe[:, :x.size(1)]
Importantly, positional embeddings depend only on the token’s position within a sequence, not on which sequence in the batch it belongs to. Therefore, the same positional embedding for position t is shared across all sequences in a batch and is broadcast during computation. This ensures that positional information encodes relative and absolute order, rather than introducing spurious batch-specific signals.
There are two types of positional embeddings:
Fixed (sinusoidal), where positions are encoded using sine and cosine functions of different frequencies, enabling the model to generalize to longer sequences and infer relative distances between tokens.
Learned, where positional vectors are trained parameters similar to word embeddings.
In both cases, positional embeddings are combined with token embeddings via element-wise addition, preserving dimensionality while enriching each token representation with sequence order information. This simple design enables Transformers to model syntax, dependencies, and context over sequences effectively, despite their parallel processing nature.
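For comparison, the learned variant mentioned above is just a trainable lookup table over positions. The following is a minimal sketch reusing the torch / nn imports above (the class name LearnedPositionalEncoding is my own, and this variant is not used elsewhere in this primer):
class LearnedPositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        # One trainable vector per position, learned jointly with the rest of the model
        self.pos_emb = nn.Embedding(max_len, d_model)
    def forward(self, x):
        # x: (batch_size, seq_len, d_model)
        positions = torch.arange(x.size(1), device=x.device)  # (seq_len,)
        return x + self.pos_emb(positions)                     # broadcast over the batch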
Scaled Dot-Product Attention¶
Scaled dot-product attention is the core mechanism that allows Transformer models to selectively focus on the most relevant parts of a sequence when processing each token. Instead of treating all tokens equally, attention computes how strongly one token should “attend” to every other token based on learned representations.
Each token is represented using three vectors:
- Query (Q): what this token is looking for
- Key (K): what each token offers
- Value (V): the information each token contains
Attention works by comparing each query with all keys using a dot product, producing a set of similarity scores that measure relevance. A higher score means the key token is more relevant to the query token. These scores are then normalized into a probability distribution and used to compute a weighted sum of the value vectors, producing a context-aware representation of each token.
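In matrix form, this is the scaled dot-product attention of Vaswani et al. (2017), which the module below implements:
$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$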
class ScaledDotProductAttention(nn.Module):
def __init__(self, d_k):
super().__init__()
# Scaling factor = sqrt(d_k)
#
# The dot product Q · K grows in magnitude with the dimensionality d_k.
# Without scaling, large dot products would push the softmax into
# extremely peaked distributions, causing:
# - Vanishing gradients
# - Overconfident, brittle attention weights
#
# Dividing by sqrt(d_k) keeps the variance of the scores roughly constant,
# stabilizing training and ensuring smoother attention distributions.
self.scale = math.sqrt(d_k)
def forward(self, Q, K, V, mask=None):
"""
Args:
Q: Queries of shape (..., seq_len_q, d_k)
K: Keys of shape (..., seq_len_k, d_k)
V: Values of shape (..., seq_len_k, d_v)
mask: Optional attention mask (broadcastable to scores shape)
Used to block padding tokens or future tokens (causal masking)
Returns:
output: Attention-weighted values
attn: Attention weight matrix
"""
# Compute raw attention scores via dot product:
# score_{i,j} = Q_i · K_j
#
# K is transposed so that:
# (..., seq_len_q, d_k) @ (..., d_k, seq_len_k)
# → (..., seq_len_q, seq_len_k)
#
# Each score measures how much token i should attend to token j.
scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
# Apply mask (if provided) to prevent attention to certain positions:
# - Padding tokens (in encoder attention)
# - Future tokens (in decoder self-attention)
#
# Masked positions are set to a large negative value so that
# softmax assigns them near-zero probability.
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# Normalize scores into a probability distribution using softmax.
# For each query token, attention weights over all key positions sum to 1.
attn = torch.softmax(scores, dim=-1)
# Compute the final attention output as a weighted sum of values:
# output_i = Σ_j attn_{i,j} · V_j
#
# This mixes information from all tokens, weighted by relevance.
return torch.matmul(attn, V), attn
A critical design choice is the scaling factor. As the dimensionality of the key vectors increases, the magnitude of dot products grows, which would cause the softmax function to become overly peaked. This would make the model excessively confident in a small number of tokens and lead to vanishing gradients during training. Dividing by sqrt(d_k) stabilizes the distribution of attention scores, ensuring smooth gradients and more balanced attention across tokens.
Another key design element is masking, which allows the model to control where attention is permitted. Masks prevent attention to padding tokens (which carry no information) and, in autoregressive settings, prevent tokens from attending to future positions. This enforces the correct information flow while keeping the attention mechanism fully parallelizable.
By combining similarity-based scoring, variance-stabilizing scaling, and flexible masking, scaled dot-product attention enables Transformers to model long-range dependencies, contextual meaning, and structured information flow efficiently. Importantly, this mechanism is fully differentiable, parallelizable, and independent of sequence length assumptions, making it a foundational building block of modern sequence models.
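A quick shape check of the module above, with arbitrary toy dimensions (d_k = 8, d_v = 16):
attn_layer = ScaledDotProductAttention(d_k=8)
Q = torch.randn(2, 5, 8)    # (batch, seq_len_q, d_k)
K = torch.randn(2, 7, 8)    # (batch, seq_len_k, d_k)
V = torch.randn(2, 7, 16)   # (batch, seq_len_k, d_v)
out, weights = attn_layer(Q, K, V)
print(out.shape)        # torch.Size([2, 5, 16])
print(weights.shape)    # torch.Size([2, 5, 7])
print(weights.sum(-1))  # every row sums to 1 (softmax over keys)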
MultiHeadAttention¶
Multi-head attention extends scaled dot-product attention by allowing the model to attend to different types of relationships simultaneously. Instead of computing a single attention distribution over the entire embedding space, the model splits the representation into multiple lower-dimensional subspaces called heads and performs attention independently in each one.
Each head learns to focus on a different aspect of the sequence. For example, one head may specialize in short-range syntactic relationships (such as subject–verb agreement), while another may focus on long-range semantic dependencies or positional structure. This decomposition allows the model to capture richer and more diverse patterns than a single attention mechanism could.
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
# d_model must be divisible by the number of heads
# because we split the embedding space evenly across heads.
# Each head operates on a lower-dimensional subspace.
assert d_model % n_heads == 0
# Dimensionality per head (d_k)
# Total dimension is preserved when heads are concatenated back.
self.d_k = d_model // n_heads
self.n_heads = n_heads
# Linear projections for Queries, Keys, and Values
#
# These allow the model to learn different representations
# of the same input depending on how it is being used:
# - As a query (what am I looking for?)
# - As a key (what do I offer?)
# - As a value (what information do I carry?)
#
# Each projects from d_model → d_model, after which we split
# into multiple heads.
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
# Output projection
#
# After attention is computed independently in each head,
# their outputs are concatenated and mixed using this layer.
# This allows information from different heads to interact.
self.W_o = nn.Linear(d_model, d_model)
# Scaled dot-product attention applied independently per head
self.attn = ScaledDotProductAttention(self.d_k)
def forward(self, Q, K, V, mask=None):
"""
Args:
Q, K, V: Input tensors of shape (batch_size, seq_len, d_model)
mask: Optional attention mask
Returns:
Output tensor of shape (batch_size, seq_len, d_model)
"""
B, T, _ = Q.shape
# Project inputs into Q, K, V spaces
# Shape after projection: (B, T, d_model)
#
# Then reshape to split into multiple heads:
# (B, T, n_heads, d_k)
#
# Finally transpose so that heads become a separate dimension:
# (B, n_heads, T, d_k)
#
# This allows attention to be computed independently per head.
Q = self.W_q(Q).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
K = self.W_k(K).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
V = self.W_v(V).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
# Apply scaled dot-product attention to each head in parallel
# Output shape: (B, n_heads, T, d_k)
x, _ = self.attn(Q, K, V, mask)
# Recombine heads:
# (B, n_heads, T, d_k) → (B, T, n_heads * d_k)
#
# .contiguous() ensures memory layout is correct before reshaping
x = x.transpose(1, 2).contiguous().view(B, T, -1)
# Final linear projection mixes information across heads
return self.W_o(x)
To enable this, the input embeddings are first projected into separate query, key, and value spaces using learned linear transformations. These projections are then reshaped so that each head operates on a reduced dimensionality. Attention is computed independently within each head using scaled dot-product attention, producing multiple parallel context representations.
After attention is applied, the outputs of all heads are concatenated and passed through a final linear projection. This step is crucial: it allows information from different heads to be combined and re-integrated into a single representation, enabling interaction between the different relational perspectives learned by each head.
A key design choice is that splitting into multiple heads does not increase computational cost relative to a single large attention operation. The total dimensionality remains constant; it is simply redistributed across heads. This makes multi-head attention both expressive and efficient.
Overall, multi-head attention gives the Transformer the ability to look at the same sequence through multiple lenses at once, making it far more powerful at modeling complex structure, long-range dependencies, and contextual meaning.
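As a quick sanity check, here is self-attention with the module above and toy dimensions (d_model = 64, 8 heads, so d_k = 8 per head):
mha = MultiHeadAttention(d_model=64, n_heads=8)
x = torch.randn(2, 10, 64)   # (batch, seq_len, d_model)
out = mha(x, x, x)           # self-attention: queries, keys, and values all come from x
print(out.shape)             # torch.Size([2, 10, 64]) — dimensionality is preserved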
FeedForward Network (Position-wise FFN)¶
1. Why do we need a FeedForward network at all?¶
Self-attention mixes information across tokens, but it is largely a linear operation with respect to feature dimensions. Without an additional non-linear component, stacking attention layers would still result in a model that is limited in its ability to learn complex transformations.
This motivation is stated explicitly in Vaswani et al., 2017 (Attention Is All You Need), where each attention sub-layer is followed by a position-wise fully connected feed-forward network to introduce non-linearity and increase representational power.
The FeedForward layer:
- Operates independently on each token (position-wise)
- Mixes information across features, not across time
- Acts like a learned feature transformer at each position
You can think of attention as answering “where should I look?” and the feedforward network as answering “how should I process what I found?”
This division of labor—global interaction via attention, local transformation via FFN—is a core architectural principle of Transformers.
Reference: Vaswani et al., Attention Is All You Need, NeurIPS 2017
2. Why expand to d_ff and then project back?¶
This is a deliberate capacity expansion strategy, introduced in the original Transformer.
In Vaswani et al. (2017), the feedforward network is defined as:
$$\text{FFN}(x) = \max(0,\; xW_1 + b_1)W_2 + b_2$$
where $W_1 \in \mathbb{R}^{d_{\text{model}} \times d_{\text{ff}}}$ and $W_2 \in \mathbb{R}^{d_{\text{ff}} \times d_{\text{model}}}$.
The reasoning:
- Expanding to a higher-dimensional space (d_ff ≫ d_model) allows the model to represent richer intermediate features.
- Projecting back to d_model keeps the interface between layers consistent, enabling deep stacking.
This mirrors a well-established pattern in deep learning:
temporary expansion → nonlinear processing → compression
Similar ideas appear in:
- Bottleneck layers in CNNs (He et al., ResNet)
- Inverted residuals in MobileNet
- MLP blocks in Vision Transformers (Dosovitskiy et al., 2020)
References: Vaswani et al., 2017; Dosovitskiy et al., An Image is Worth 16×16 Words, ICLR 2021
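As a rough back-of-the-envelope check using the base-model sizes from the paper (d_model = 512, d_ff = 2048), the expand-then-compress design is where most of a block's parameters live:
d_model, d_ff = 512, 2048
ffn_params = (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)  # W1 + b1, then W2 + b2
print(f"{ffn_params:,}")  # 2,099,712 parameters per layer (~2.1M)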
3. Why residual connections everywhere?¶
Residual connections were introduced to address the optimization difficulty of deep networks (He et al., 2016) and were adopted wholesale in Transformers.
In the Transformer, every sub-layer (attention and feedforward) is wrapped with a residual connection:
$$x \rightarrow x + \text{sublayer}(x)$$
This ensures that:
- The model can easily learn identity mappings
- Gradients flow effectively through many stacked layers
- Deeper models remain trainable and stable
Vaswani et al. explicitly note that residual connections are critical for training deep attention-based architectures.
Formally, residuals allow layers to learn incremental refinements rather than complete transformations, which significantly improves optimization.
References: He et al., Deep Residual Learning for Image Recognition, CVPR 2016; Vaswani et al., 2017
4. Why LayerNorm after each sub-layer?¶
Layer Normalization was chosen because:
- It normalizes per token, not across the batch
- It works well with variable sequence lengths
- It is stable under parallel computation
In the original Transformer, LayerNorm is applied after the residual connection (post-norm):
x = LayerNorm(x + sublayer(x))
LayerNorm helps:
- Reduce internal covariate shift
- Prevent exploding or vanishing activations
- Stabilize training across layers
Later work (e.g., Pre-LN Transformers) showed that moving LayerNorm before the sub-layer improves gradient flow in very deep models, but the core motivation remains unchanged.
References: Ba et al., Layer Normalization, arXiv 2016; Vaswani et al., 2017; Xiong et al., On Layer Normalization in Transformers, ICML 2020
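A small illustration of “per token, not across the batch”: LayerNorm normalizes over the feature dimension of each position independently (the shapes here are toy values):
ln = nn.LayerNorm(8)
x = torch.randn(2, 5, 8)               # (batch, seq_len, d_model)
y = ln(x)
print(y.mean(dim=-1))                  # ~0 for every (batch, position) pair
print(y.std(dim=-1, unbiased=False))   # ~1 for every (batch, position) pair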
5. Why split attention and feedforward into two sub-layers?¶
This separation is a deliberate inductive bias introduced in the original Transformer architecture.
Each Encoder block alternates between:
- Context aggregation (self-attention)
- Feature transformation (position-wise feedforward)
This enforces a clean conceptual separation:
- Attention decides which tokens influence each other
- Feedforward decides how each token representation should be transformed
Stacking these blocks allows:
- Global information to propagate through attention
- Local, non-linear refinement at each layer
This alternating structure is one of the key reasons Transformers scale so well with depth and data.
Reference: Vaswani et al., 2017
class FeedForward(nn.Module):
def __init__(self, d_model, d_ff):
super().__init__()
# Position-wise FeedForward Network
#
# This is applied independently to each token position.
# It does NOT mix information across time (sequence length),
# only across feature dimensions.
#
# The structure expands the representation into a higher-
# dimensional space (d_ff), applies a non-linearity,
# and then projects it back to d_model.
#
# This allows the model to:
# - Introduce non-linear transformations
# - Increase representational capacity
# - Learn complex feature interactions per token
self.net = nn.Sequential(
nn.Linear(d_model, d_ff), # Expansion (feature mixing)
nn.ReLU(), # Non-linearity
nn.Linear(d_ff, d_model) # Compression back to model dimension
)
def forward(self, x):
# x shape: (batch_size, seq_len, d_model)
# FeedForward is applied identically to every token position
return self.net(x)
class EncoderBlock(nn.Module):
def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
super().__init__()
# Multi-head self-attention:
# Allows each token to attend to all other tokens in the sequence
# and build context-aware representations.
self.attn = MultiHeadAttention(d_model, n_heads)
# Position-wise feedforward network:
# Adds non-linear transformation capacity after attention.
self.ff = FeedForward(d_model, d_ff)
# Layer normalization layers
#
# These stabilize training by normalizing feature distributions
# and help gradient flow in deep transformer stacks.
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
# Dropout for regularization
# Prevents overfitting and co-adaptation of neurons.
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
"""
Args:
x: Input tensor of shape (batch_size, seq_len, d_model)
mask: Optional attention mask
Returns:
Output tensor of shape (batch_size, seq_len, d_model)
"""
# --- Self-Attention Sub-layer ---
#
# Residual connection:
# Preserves the original representation and improves gradient flow.
#
# Dropout:
# Regularizes attention outputs.
#
# LayerNorm (post-norm formulation here):
# Stabilizes activations and training dynamics.
x = self.norm1(
x + self.dropout(self.attn(x, x, x, mask))
)
# --- FeedForward Sub-layer ---
#
# Again, we apply:
# - Position-wise feedforward transformation
# - Dropout for regularization
# - Residual connection to preserve information
# - LayerNorm for stability
x = self.norm2(
x + self.dropout(self.ff(x))
)
return x
The Transformer encoder block combines ideas from residual learning, normalization, and MLP expansion into a modular design where attention handles interaction and feedforward layers handle transformation. This structure, first formalized in Attention Is All You Need, remains largely unchanged across modern large language models.
Decoder Block — Why it exists and how it works¶
1. Why does the decoder have masked self-attention?¶
Unlike the encoder, the decoder is used for autoregressive generation. When predicting token t, the model must not access tokens t+1, t+2, ….
This constraint is enforced using a causal (look-ahead) mask in self-attention, as introduced in Vaswani et al., 2017. The mask ensures that each position can only attend to earlier positions, preserving the correct probabilistic factorization.
Without masked self-attention, the decoder would leak future information and collapse training.
2. Why do we need cross-attention at all?¶
Cross-attention is what allows the decoder to condition on the source sequence.
- Queries come from the decoder (current generation state)
- Keys and values come from the encoder (encoded source information)
This mechanism allows the decoder to dynamically decide:
“Which parts of the input sequence are relevant right now?”
This design replaces fixed alignment mechanisms used in earlier sequence-to-sequence models (e.g., RNN encoder–decoders) with a fully differentiable, parallelizable alternative.
References: Bahdanau et al., Neural Machine Translation by Jointly Learning to Align and Translate, 2015; Vaswani et al., 2017
3. Why is self-attention applied before cross-attention?¶
This ordering is intentional and appears in the original Transformer.
The decoder first:
- Builds a representation of what has been generated so far (self-attention)
- Then aligns that representation with the source sequence (cross-attention)
Conceptually:
- Self-attention answers: “What have I already said?”
- Cross-attention answers: “What does the input tell me to say next?”
Reversing this order degrades performance, as the decoder would try to attend to the source without a coherent internal state.
4. Why does the decoder also need a FeedForward network?¶
Just like in the encoder, attention alone is insufficient.
- Attention mixes information across tokens
- FeedForward networks introduce non-linearity and feature transformation
In the decoder, the FFN refines the token representation after both:
- past context (self-attention), and
- source context (cross-attention)
This ensures expressive power at each decoding step.
5. Why three residual + normalization blocks?¶
Each sub-layer (self-attn, cross-attn, FFN) performs a fundamentally different operation and has different activation statistics. Wrapping each with its own residual connection and LayerNorm:
- Stabilizes training
- Preserves information flow
- Allows deep stacking of decoder blocks
This mirrors the encoder design but extends it to handle conditional generation.
class DecoderBlock(nn.Module):
def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
super().__init__()
# Masked multi-head self-attention
#
# This allows each target token to attend only to:
# - itself
# - previous target tokens
#
# The causal (look-ahead) mask ensures autoregressive behavior:
# the model cannot "see the future" during generation.
self.self_attn = MultiHeadAttention(d_model, n_heads)
# Cross-attention (encoder–decoder attention)
#
# Queries come from the decoder (what do I need?),
# Keys and Values come from the encoder (what information is available?).
#
# This is how the decoder conditions its predictions on the source sequence.
self.cross_attn = MultiHeadAttention(d_model, n_heads)
# Position-wise feedforward network
#
# Applies non-linear transformation independently to each target token
# after contextual information has been integrated.
self.ff = FeedForward(d_model, d_ff)
# LayerNorm layers for each sub-layer
#
# Separate norms are used because each sub-layer has
# different statistical properties.
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
# Dropout for regularization
self.dropout = nn.Dropout(dropout)
def forward(self, x, enc_out, tgt_mask=None, src_mask=None):
"""
Args:
x: Decoder input embeddings (batch_size, tgt_len, d_model)
enc_out: Encoder outputs (batch_size, src_len, d_model)
tgt_mask: Causal mask to block future target tokens
src_mask: Mask to block padding in the source sequence
Returns:
Decoder output representations
"""
# Allows each target token to attend only to earlier target tokens.
# This enforces the autoregressive property required for generation.
#
# Residual connection + dropout + LayerNorm stabilize training
# and preserve the original token representation.
x = self.norm1(
x + self.dropout(self.self_attn(x, x, x, tgt_mask))
)
# Decoder queries attend over encoder keys and values.
# This is the information bottleneck where the decoder
# decides which parts of the source sequence are relevant
# for predicting the next token.
x = self.norm2(
x + self.dropout(self.cross_attn(x, enc_out, enc_out, src_mask))
)
# Applies a non-linear, position-wise transformation
# to refine the decoder representations after all
# contextual information has been integrated.
x = self.norm3(
x + self.dropout(self.ff(x))
)
return x
The decoder first reasons about what it has already generated, then selectively consults the source sequence, and finally refines the result through non-linear transformation — all while strictly preventing future information leakage.
Transformer — Why this architecture works¶
1. Why separate encoder and decoder?¶
The Transformer is designed for sequence-to-sequence tasks (e.g., translation).
The encoder reads and understands the entire input sequence.
The decoder generates the output sequence one token at a time, conditioned on:
- what it has already generated, and
- the encoder’s representation of the input.
This separation cleanly mirrors the probabilistic goal:
Generate an output sequence conditioned on an input sequence.
2. Why stack multiple layers?¶
A single attention layer can only perform one round of information exchange.
By stacking layers:
- Early layers capture local or shallow relationships
- Later layers build more abstract, global representations
This hierarchical refinement is analogous to depth in CNNs and deep MLPs.
3. Why positional encoding at the input?¶
Attention alone has no notion of order. Without positional encodings:
["dog", "bites", "man"]
["man", "bites", "dog"]
would look identical to the model.
Adding positional encodings at the input ensures that order information flows through every layer of the network.
4. Why does the encoder run fully in parallel?¶
Unlike RNNs:
- There is no recurrence
- No dependency between time steps inside the encoder
This allows:
- Massive parallelism on GPUs/TPUs
- Faster training
- Better scaling to long sequences
This was one of the key motivations behind Attention Is All You Need.
5. Why project back to vocabulary with fc_out?¶
The model internally operates in a continuous vector space (d_model), but the final task is classification over discrete tokens.
The final linear layer:
- Converts hidden states into logits
- One logit per vocabulary token
- Enables training with cross-entropy loss
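Concretely, during training the logits produced by fc_out are flattened and compared against the gold next tokens with cross-entropy; the shapes below are illustrative (vocab = 10000, tgt_len = 20):
import torch.nn.functional as F
logits = torch.randn(2, 20, 10000)           # (batch, tgt_len, tgt_vocab) — what fc_out produces
targets = torch.randint(0, 10000, (2, 20))   # gold next-token IDs
loss = F.cross_entropy(
    logits.view(-1, logits.size(-1)),        # (batch * tgt_len, vocab)
    targets.view(-1)                         # (batch * tgt_len,)
)
print(loss.item())                           # scalar training loss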
class Transformer(nn.Module):
def __init__(
self,
src_vocab,
tgt_vocab,
d_model=512,
n_heads=8,
d_ff=2048,
n_layers=6,
max_len=512
):
super().__init__()
# Source and target token embeddings
#
# These map discrete token IDs to continuous vectors of size d_model.
# Separate embeddings are used because source and target vocabularies
# may differ and play different roles in the model.
self.src_embed = nn.Embedding(src_vocab, d_model)
self.tgt_embed = nn.Embedding(tgt_vocab, d_model)
# Positional encoding
#
# Adds explicit sequence order information to token embeddings.
# This is required because attention itself is permutation-invariant.
self.pos = PositionalEncoding(d_model, max_len)
# Encoder stack
#
# A stack of identical EncoderBlocks.
# Each block refines representations by:
# 1) allowing tokens to attend to each other (self-attention)
# 2) applying non-linear, position-wise transformations (FFN)
self.encoder = nn.ModuleList([
EncoderBlock(d_model, n_heads, d_ff)
for _ in range(n_layers)
])
# Decoder stack
#
# Each DecoderBlock:
# 1) builds an autoregressive representation of the target prefix
# 2) attends to the encoder outputs (cross-attention)
# 3) refines representations with a feedforward network
self.decoder = nn.ModuleList([
DecoderBlock(d_model, n_heads, d_ff)
for _ in range(n_layers)
])
# Final linear projection
#
# Maps decoder hidden states back to the target vocabulary space,
# producing logits for next-token prediction.
self.fc_out = nn.Linear(d_model, tgt_vocab)
def forward(self, src, tgt, src_mask=None, tgt_mask=None):
"""
Args:
src: Source token indices (batch_size, src_len)
tgt: Target token indices (batch_size, tgt_len)
src_mask: Mask for source padding tokens
tgt_mask: Causal mask for target tokens
Returns:
Logits over target vocabulary for each target position
"""
# --- Embedding + Positional Encoding ---
#
# Token embeddings provide semantic meaning.
# Positional encodings inject order information.
#
# Resulting shape: (batch_size, seq_len, d_model)
src = self.pos(self.src_embed(src))
tgt = self.pos(self.tgt_embed(tgt))
# --- Encoder ---
#
# The encoder processes the entire source sequence in parallel.
# Its output is a contextual representation of the source,
# which will be attended to by the decoder.
for layer in self.encoder:
src = layer(src, src_mask)
# --- Decoder ---
#
# The decoder processes the target sequence autoregressively.
# At each layer, it:
# - attends to past target tokens (masked self-attention)
# - attends to encoder outputs (cross-attention)
for layer in self.decoder:
tgt = layer(tgt, src, tgt_mask, src_mask)
# --- Output Projection ---
#
# Convert decoder representations into vocabulary logits.
# These logits are typically passed to a softmax during training
# or used for greedy/beam search during inference.
return self.fc_out(tgt)
The Transformer first builds a deep, contextual understanding of the input sequence, then generates the output sequence step by step—each time consulting both its past outputs and the encoded input—using only attention and feedforward layers.
Causal Mask¶
1. What problem does the causal mask solve?¶
In autoregressive decoding, the model must obey:
When predicting token t, it can only use tokens 0 … t−1.
Without a causal mask, self-attention would allow every token to see the entire sequence, including future tokens, which would:
- Break the probabilistic factorization of sequence generation
- Cause information leakage during training
- Make inference behavior inconsistent with training
2. Why a lower-triangular matrix?¶
Consider a sequence of length 4:
tokens: y0 y1 y2 y3
The causal mask looks like:
1 0 0 0
1 1 0 0
1 1 1 0
1 1 1 1
Row i corresponds to query position i.
Column j corresponds to key position j.
A 1 means “attention allowed”, a 0 means “blocked”.
This structure enforces:
- Past → allowed
- Present → allowed
- Future → blocked
3. Why not just use zeros and ones directly?¶
The mask itself is not applied directly to the output. Instead, it is used inside attention as:
scores = scores.masked_fill(mask == 0, -∞)
This ensures that:
- Softmax assigns zero probability to future positions
- Gradients do not flow through masked connections
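A tiny demonstration of why the large negative fill works (the scores here are made up):
scores = torch.tensor([2.0, 1.0, float("-inf"), float("-inf")])
print(torch.softmax(scores, dim=-1))  # tensor([0.7311, 0.2689, 0.0000, 0.0000]) — masked positions get zero weight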
4. Why add two extra dimensions?¶
Attention scores have shape:
(batch_size, n_heads, query_len, key_len)
By returning a mask of shape:
(1, 1, size, size)
we allow PyTorch to broadcast the same causal structure across:
- all batches
- all attention heads
This ensures that causality is enforced globally and consistently.
5. Why is this essential even during training?¶
Even when the full target sequence is known (teacher forcing):
- The model must behave as if the future is unknown
- Otherwise, it learns shortcuts that do not exist at inference time
This alignment between training and inference is critical for stable generation.
def causal_mask(size):
# Create a square matrix of ones with shape (size, size)
# This represents all possible query–key positions.
mask = torch.ones(size, size)
# Keep only the lower-triangular part of the matrix.
#
# Positions above the diagonal correspond to "future" tokens
# (i.e., positions j > i when predicting token i).
#
# torch.tril enforces the causal constraint:
# token i can attend to tokens {0, ..., i}
# but NOT to tokens {i+1, ..., size-1}
mask = torch.tril(mask)
# Add two singleton dimensions so the mask can be broadcast
# across:
# - batch dimension
# - attention heads
#
# Final shape: (1, 1, size, size)
# This matches attention score tensors of shape:
# (batch_size, n_heads, seq_len, seq_len)
return mask.unsqueeze(0).unsqueeze(1)
The causal mask enforces the rule “no peeking into the future” by blocking attention to tokens that haven’t been generated yet.
model = Transformer(
src_vocab=10000,
tgt_vocab=10000,
d_model=512,
n_heads=8,
n_layers=6
)
src = torch.randint(0, 10000, (2, 20))
tgt = torch.randint(0, 10000, (2, 20))
mask = causal_mask(tgt.size(1))
out = model(src, tgt, tgt_mask=mask)
print(out.shape) # (batch, seq_len, vocab)
out
So... how do we go about generating text?¶
Transformers do not generate text; they generate next-token probability distributions. You decide how to sample from them.
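For example, given a single vector of next-token logits, “deciding how to sample” can mean taking the argmax (greedy decoding) or sampling from a temperature-scaled softmax, which is what the generate function later in this notebook does. The logits below are made up:
logits = torch.tensor([2.0, 1.0, 0.5, -1.0])         # hypothetical next-token logits
greedy = torch.argmax(logits)                        # greedy decoding: always picks token 0
temperature = 0.8
probs = torch.softmax(logits / temperature, dim=-1)  # T < 1 sharpens, T > 1 flattens the distribution
sampled = torch.multinomial(probs, num_samples=1)    # stochastic choice, different on each call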
Decoder-Only vs Encoder-Decoder¶
Some LLMs are encoder–decoder, but most modern general-purpose LLMs are decoder-only. The split is very clean once you classify them by what probability they model.
You should always start by asking: what probability distribution am I trying to learn?
For TinyStories, the goal is to model P(x₁, x₂, …, x_T), which factorizes as P(x₁) · P(x₂ | x₁) · P(x₃ | x₁, x₂) · …
This is exactly an autoregressive process: each token depends only on the tokens before it.
A decoder-only transformer is built precisely to model P(x_t | x₁ … x_{t−1}) using masked self-attention, which prevents the model from seeing future tokens. The architecture directly enforces the correct causal structure.
An encoder–decoder transformer models a different distribution: P(y | x)
That only makes sense when there is a separate input sequence x, such as:
- translation (French given English)
- summarization (summary given document)
- QA (answer given question)
TinyStories has no separate input. There is nothing to encode. Adding an encoder would give the model access to the entire story during training, which breaks the causal assumption and creates a mismatch between training and generation.
Decoder-only models also guarantee training–inference alignment:
- during training, the model only sees past tokens
- during generation, the model only sees past tokens
Encoder–decoder models cannot guarantee this alignment for language modeling without artificial constraints.
So the reason is not “because GPT does it”.
The real reason is:
Architecture follows probability factorization.
If you want to model P(x), use a decoder-only transformer. If you want to model P(y|x), use an encoder–decoder transformer.
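To make the factorization concrete, here is a sketch of how you would score a sequence under a trained decoder-only model: sum the per-token log-probabilities log P(x_t | x_<t). The helper name sequence_log_prob is my own, and it assumes a model and tokenizer like the ones built below, with the sequence fitting inside BLOCK_SIZE:
import torch
import torch.nn.functional as F

@torch.no_grad()
def sequence_log_prob(model, token_ids):
    # token_ids: 1-D tensor of token IDs, length <= BLOCK_SIZE + 1
    x = token_ids[:-1].unsqueeze(0)    # context:  x_1 ... x_{T-1}
    y = token_ids[1:].unsqueeze(0)     # targets:  x_2 ... x_T
    logits = model(x)                  # (1, T-1, vocab_size)
    log_probs = F.log_softmax(logits, dim=-1)
    token_lp = log_probs.gather(-1, y.unsqueeze(-1)).squeeze(-1)  # log P(x_t | x_<t)
    return token_lp.sum().item()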
Configuration¶
! curl "https://cas-bridge.xethub.hf.co/xet-bridge-us/645e8da96320b0efe40ade7a/02e40cc51c59a4bc6c51bd7bc9acda4316e208745be060558eaf500cd14e9f96?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Content-Sha256=UNSIGNED-PAYLOAD&X-Amz-Credential=cas%2F20260112%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20260112T124634Z&X-Amz-Expires=3600&X-Amz-Signature=a948d01680e37e90762ec67ccadf0d597a9a94e89dd28da762376ed10ed60b41&X-Amz-SignedHeaders=host&X-Xet-Cas-Uid=6507eb42423b46492edf979c&response-content-disposition=inline%3B+filename*%3DUTF-8%27%27TinyStoriesV2-GPT4-train.txt%3B+filename%3D%22TinyStoriesV2-GPT4-train.txt%22%3B&response-content-type=text%2Fplain&x-id=GetObject&Expires=1768225594&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc2ODIyNTU5NH19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2FzLWJyaWRnZS54ZXRodWIuaGYuY28veGV0LWJyaWRnZS11cy82NDVlOGRhOTYzMjBiMGVmZTQwYWRlN2EvMDJlNDBjYzUxYzU5YTRiYzZjNTFiZDdiYzlhY2RhNDMxNmUyMDg3NDViZTA2MDU1OGVhZjUwMGNkMTRlOWY5NioifV19&Signature=BPYTyJ2CTFy3Jb5MCHF4i%7EEAndHrAJP8bK6GP7tv6REvFarHH3eAur3dyE6w-7eo7PJGKzzeQDodxSYhwHQE95b2RuywZ5DlxTS%7EelkvlI52suIS6vgxa2bkGq5sW7zD1LAzuP3UEoJ1mniA7vq8WbQ2OPWKy%7ET87Zc8ieGiMZ7KoEOy4OpiUhY3SiU3e%7EI43nHwlcEQEGQ4VpRG5OlQNEnOSwSUhm4UlHIvz6gt3smSOCvgCV7l3MTY0CpPU00YwTp7w0NbIaHsTMuuk8N0XBVeAl%7EdE0o1qs3RSZeUeY2grJbVaxlevb657i0R%7E1uHDZ1%7E-ctEGcXRXpMcLqvlkw__&Key-Pair-Id=K2L8F4GPSG1IFC" -o TinyStoriesV2-GPT4-train.txt
% Total % Received % Xferd Average Speed Time Time Time Current
Dload Upload Total Spent Left Speed
100 2124M 100 2124M 0 0 128M 0 0:00:16 0:00:16 --:--:-- 66.8M
! head -c 100 TinyStoriesV2-GPT4-train.txt
Once upon a time there was a little boy named Ben. Ben loved to explore the world around him. He sa
import torch
DATA_PATH = "TinyStoriesV2-GPT4-train.txt" # change if needed
MAX_CHARS = 500_000 # limit for quick runs
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BLOCK_SIZE = 512
BATCH_SIZE = 128 # or 256 if it fits
GRAD_ACCUM = 1
D_MODEL = 256
N_HEADS = 4
N_LAYERS = 4
D_FF = 1024
LR = 3e-4
EPOCHS = 80
AMP = True
Dataset¶
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
LanguageModelDataset (TinyStories)¶
Historical context¶
Early language models (n-gram models, HMMs) modeled text as a probability distribution over sequences using the factorization:
The probability of a sentence is the product of the probability of each word given all previous words.
Neural language models (Bengio et al., 2003) adopted the same formulation, and this objective was later carried over to RNNs, LSTMs, and eventually Transformers.
When Transformer-based language models (Radford et al., 2018) removed recurrence entirely, they still preserved this exact probabilistic structure. The only requirement was a dataset that could generate many examples of:
“Given a sequence of previous tokens, predict the next token.”
This dataset class exists to construct that learning signal from raw tokenized text such as TinyStories.
Why sliding windows?¶
TinyStories is provided as a long stream of tokens, not pre-segmented into independent training samples.
The design choice here is to use a sliding window over the token stream:
- Each window provides a fixed-length context (block_size)
- The target is the same window shifted by one token
This idea predates Transformers and was used in:
- Feedforward neural language models
- RNN-based language models
- GPT-style models
It ensures:
- Efficient use of all tokens
- Strong local coherence learning
- Compatibility with causal self-attention
Why predict a sequence instead of a single token?¶
Although the objective is “predict the next token,” the model predicts the next token at every position simultaneously.
This aligns with:
- Parallel training (a core advantage of Transformers)
- Causal masking in self-attention
- Efficient GPU utilization
Each training example teaches the model multiple prediction tasks in one forward pass.
class LanguageModelDataset(Dataset):
def __init__(self, data, block_size):
# `data` is a long sequence of token IDs representing TinyStories.
# Example:
# [12, 45, 891, 34, 78, ...]
#
# `block_size` defines the length of the context window the model
# is allowed to condition on when predicting the next token.
self.data = data
self.block_size = block_size
def __len__(self):
# Each training example requires `block_size + 1` tokens:
# - `block_size` tokens for the input
# - 1 additional token for the shifted target
#
# We subtract `block_size` to avoid indexing past the end
# of the token sequence.
return len(self.data) - self.block_size
def __getitem__(self, idx):
# Input sequence (context window)
#
# This represents what the model is allowed to see.
# Shape: (block_size,)
x = self.data[idx : idx + self.block_size]
# Target sequence (next-token labels)
#
# This is the same sequence shifted one position to the right.
# For each position t, the model learns to predict y[t]
# given x[0:t].
y = self.data[idx + 1 : idx + self.block_size + 1]
return x, y
This dataset turns TinyStories into overlapping context–next-token prediction tasks, enforcing the same left-to-right learning objective that underpins all modern autoregressive language models.
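A toy illustration of the shift (the token IDs are made up):
toy = torch.tensor([10, 11, 12, 13, 14, 15])
ds = LanguageModelDataset(toy, block_size=4)
x, y = ds[0]
print(x)        # tensor([10, 11, 12, 13])
print(y)        # tensor([11, 12, 13, 14]) — the same window shifted right by one
print(len(ds))  # 2 valid starting positions in this tiny stream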
Tokenizer¶
BPE (Byte Pair Encoding) Tokenization & Data Pipeline (TinyStories)¶
Why BPE tokenization?¶
This code uses Byte Pair Encoding (BPE) tokenization, a subword tokenization approach that balances vocabulary size and sequence length.
Historically:
- BPE was introduced for text compression (Gage, 1994) and later adapted for NLP (Sennrich et al., 2016)
- Modern language models (GPT, BERT, etc.) use subword tokenization like BPE to:
- handle rare and unknown words gracefully
- reduce vocabulary size compared to word-level tokenization
- maintain reasonable sequence lengths
BPE became popular in NLP through works like:
- Sennrich et al., 2016 – "Neural Machine Translation of Rare Words with Subword Units"
For TinyStories, BPE tokenization is a good choice because:
- It handles the full vocabulary efficiently
- Reduces sequence length compared to character-level tokenization
- Provides a more realistic setup similar to production language models
Why encode the entire corpus into one long tensor?¶
data = torch.tensor(tokenizer.encode(text), dtype=torch.long)
Instead of treating each story independently, the dataset is represented as a single continuous stream of tokens.
This aligns with:
- Autoregressive language modeling theory
- Sliding-window dataset construction
- GPT-style training pipelines
The model learns:
“Text is a continuous stream where any position can follow any other position.”
This assumption simplifies training and works well in practice.
Why use LanguageModelDataset + DataLoader?¶
This combination:
- Converts the token stream into overlapping (x, y) training pairs
- Enables batching, shuffling, and parallel loading
- Decouples data logic from model logic
Historically, this design emerged as PyTorch standardized:
- Dataset for data definition
- DataLoader for efficient iteration
Why shuffle and drop last?¶
shuffle=True
drop_last=True
- Shuffling prevents the model from seeing data in a fixed order, improving generalization
- Dropping the last batch ensures consistent batch sizes, which simplifies:
- tensor shapes
- attention masking
- GPU efficiency
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
class BPETokenizer:
def __init__(self, text, vocab_size=1000):
self.vocab_size = vocab_size
# Start with character-level vocabulary
chars = sorted(list(set(text)))
self.vocab = chars.copy()
# Create initial merges (character pairs)
self.merges = {}
# Convert text to list of tokens (initially characters)
tokens = list(text)
# Learn BPE merges
for _ in range(vocab_size - len(chars)):
# Count frequency of each adjacent pair
pair_freqs = {}
for i in range(len(tokens) - 1):
pair = (tokens[i], tokens[i + 1])
pair_freqs[pair] = pair_freqs.get(pair, 0) + 1
if not pair_freqs:
break
# Find most frequent pair
best_pair = max(pair_freqs, key=pair_freqs.get)
# Create new token for this pair
new_token = ''.join(best_pair)
self.merges[best_pair] = new_token
self.vocab.append(new_token)
# Merge all occurrences of this pair in the token list
i = 0
while i < len(tokens) - 1:
if (tokens[i], tokens[i + 1]) == best_pair:
tokens[i] = new_token
del tokens[i + 1]
else:
i += 1
# Create mappings
self.stoi = {token: i for i, token in enumerate(self.vocab)}
self.itos = {i: token for token, i in self.stoi.items()}
def encode(self, s):
# Start with character-level encoding
tokens = list(s)
# Apply BPE merges greedily
changed = True
while changed:
changed = False
i = 0
while i < len(tokens) - 1:
pair = (tokens[i], tokens[i + 1])
if pair in self.merges:
tokens[i] = self.merges[pair]
del tokens[i + 1]
changed = True
else:
i += 1
return [self.stoi[token] for token in tokens]
def decode(self, ids):
# Convert IDs back to tokens
tokens = [self.itos[i] for i in ids]
return ''.join(tokens)
print("Loading data...")
with open(DATA_PATH, "r", encoding="utf-8") as f:
text = f.read()[:MAX_CHARS]
tokenizer = BPETokenizer(text, vocab_size=1000)
data = torch.tensor(tokenizer.encode(text), dtype=torch.long)
# ---------------- SPLIT ----------------
split_ratio = 0.9
split_idx = int(len(data) * split_ratio)
train_data = data[:split_idx]
val_data = data[split_idx:]
# ---------------- DATASETS ----------------
train_dataset = LanguageModelDataset(train_data, BLOCK_SIZE)
val_dataset = LanguageModelDataset(val_data, BLOCK_SIZE)
# ---------------- LOADERS ----------------
train_loader = DataLoader(
train_dataset,
batch_size=BATCH_SIZE,
num_workers=3,
shuffle=True,
drop_last=True
)
val_loader = DataLoader(
val_dataset,
batch_size=BATCH_SIZE,
num_workers=2,
shuffle=False, # IMPORTANT
drop_last=True
)
Loading data...
/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py:627: UserWarning: This DataLoader will create 3 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary. warnings.warn(
This pipeline converts raw TinyStories text into a stream of BPE subword tokens and then into overlapping context–next-token prediction tasks, exactly matching the autoregressive objective used by GPT-style language models.
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len):
super().__init__()
pe = torch.zeros(max_len, d_model)
pos = torch.arange(0, max_len).unsqueeze(1)
div = torch.exp(
torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
)
pe[:, 0::2] = torch.sin(pos * div)
pe[:, 1::2] = torch.cos(pos * div)
self.register_buffer("pe", pe.unsqueeze(0))
def forward(self, x):
return x + self.pe[:, :x.size(1)]
GPT¶
MultiAttentionHead¶
Differences vs earlier attention (only real changes)¶
1. Self-attention only (no Q, K, V inputs)¶
Earlier
forward(self, Q, K, V, mask=None)
Now
forward(self, x, mask=None)
What changed
- Q, K, V are all derived from the same tensor x
- This is pure self-attention only
Why
- GPT-style decoder-only models never use cross-attention
- Simplifies the API and reduces surface area for bugs
- Matches how GPT actually operates
2. Attention logic is inlined (no ScaledDotProductAttention module)¶
Earlier
self.attn = ScaledDotProductAttention(self.d_k)
x, _ = self.attn(Q, K, V, mask)
Now
scores = (q @ k.transpose(-2, -1)) / sqrt(d_k)
attn = softmax(scores)
out = attn @ v
What changed
- Attention computation is written inline
- No separate abstraction
Why
- GPT-style minimalism
- Fewer function calls
- Easier to fuse / optimize later
- Matches most real GPT codebases
3. Causal masking assumed externally¶
Earlier
- Masking semantics varied (encoder mask, decoder mask, src/tgt)
Now
scores = scores.masked_fill(mask == 0, -1e9)
What changed
- This module assumes the mask is already causal and broadcastable
- No logic to construct or interpret masks internally
Why
- Cleaner separation of concerns
- GPT always uses causal masking
- Mask construction belongs to the model, not attention
4. No attention weights returned¶
Earlier
return output, attn
Now
return self.W_o(out)
What changed
- Attention weights are discarded
Why
- GPT training never uses attention maps
- Saves memory
- Reduces bandwidth and overhead
5. Single-path projection (GPT style)¶
This version:
- Projects Q/K/V from the same input
- Concatenates heads
- Applies one output projection
This matches the attention used in GPT-style models (GPT-1 onward).
What did not change (important)¶
- Scaled dot-product attention
- Multi-head splitting
- Softmax over keys
- Residual compatibility
- Parallel attention computation
Only the scope and structure changed, not the math.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
"""
GPT-style multi-head self-attention
"""
def __init__(self, d_model: int, n_heads: int):
super().__init__()
assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads
# Fused projection for Q, K, V (faster + cleaner)
self.qkv_proj = nn.Linear(d_model, 3 * d_model)
# Output projection
self.out_proj = nn.Linear(d_model, d_model)
def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
"""
Args:
x: (B, T, d_model)
mask: (1 or B, 1 or n_heads, T, T) causal mask (0 = block)
Returns:
(B, T, d_model)
"""
B, T, C = x.shape
# ------------------------------------------------
# Project & split into Q, K, V
# ------------------------------------------------
# (B, T, 3 * C) → (B, T, 3, n_heads, d_k)
qkv = self.qkv_proj(x).view(B, T, 3, self.n_heads, self.d_k)
# (B, n_heads, T, d_k)
q, k, v = qkv.unbind(dim=2)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
# ------------------------------------------------
# Scaled dot-product attention
# ------------------------------------------------
# (B, n_heads, T, T)
scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, float("-inf"))
attn = F.softmax(scores, dim=-1)
# (B, n_heads, T, d_k)
out = attn @ v
# ------------------------------------------------
# Recombine heads
# ------------------------------------------------
# (B, T, C)
out = out.transpose(1, 2).contiguous().view(B, T, C)
return self.out_proj(out)
This version removes generality (cross-attention, modular abstractions) in favor of a minimal, self-attention–only implementation that matches how GPT actually works in practice.
1. Decoder-only instead of encoder–decoder¶
Earlier Transformer (Vaswani et al., 2017)¶
Designed for sequence-to-sequence tasks (e.g., translation)
Architecture:
- Encoder stack
- Decoder stack
- Cross-attention between them
GPT (Radford et al., 2018 → GPT-4)¶
Designed for language modeling and generation
Architecture:
- Decoder-only Transformer
- No encoder
- No cross-attention
Why the change?
Language modeling only requires:
“Predict the next token given all previous tokens.”
There is no source sequence to condition on, so:
- Encoder becomes unnecessary
- Cross-attention is removed
- Model becomes simpler and more scalable
This simplification was one of the key insights behind GPT.
2. Architectural evolution: Pre-LN instead of Post-LN¶
Earlier code (Encoder / Decoder blocks)¶
we used post-norm:
x = LayerNorm(x + sublayer(x))
GPT-style block¶
This implementation uses pre-norm:
x = x + sublayer(LayerNorm(x))
Why this change?
Research showed that:
- Post-LN Transformers become unstable at depth
- Gradients struggle to flow through many layers
Pre-LN:
- Improves gradient flow
- Allows very deep models (100+ layers)
- Is now standard in GPT, LLaMA, PaLM, etc.
Key reference: Xiong et al., On Layer Normalization in Transformers, ICML 2020
3. Single attention type: masked self-attention only¶
Earlier:
- Encoder self-attention
- Decoder masked self-attention
- Decoder cross-attention
GPT:
- Only masked self-attention
self.attn = MultiHeadAttention(d_model, n_heads)
with:
mask = causal_mask(T)
Why?
- GPT models a single sequence
- Causality enforces correct generation
- Cross-attention is unnecessary
4. Weight tying (embedding ↔ output projection)¶
self.head.weight = self.embed.weight
This technique was introduced and popularized in:
- Press & Wolf (2017)
- GPT-1 onward
Why weight tying?
Reduces parameter count
Improves generalization
Enforces symmetry between:
- “reading” a token (embedding)
- “writing” a token (logits)
This is now standard practice.
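As a rough illustration of the savings with this notebook's configuration (vocab_size = 1000, D_MODEL = 256):
vocab_size, d_model = 1000, 256
untied = 2 * vocab_size * d_model   # separate embedding matrix + separate output projection weight
tied = vocab_size * d_model         # one shared matrix
print(untied - tied)                # 256,000 parameters saved (the head's bias is unaffected)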
5. Final LayerNorm before output¶
self.ln_f = nn.LayerNorm(D_MODEL)
This final normalization layer:
- Stabilizes logits
- Improves training dynamics
- Became standard in GPT-style models
Earlier Transformers normalized inside blocks only.
6. Positional encoding remains (but later evolved)¶
This code still uses sinusoidal positional encoding, but historically:
- GPT-1 / GPT-2 / GPT-3 → learned positional embeddings
- GPT-NeoX, LLaMA, and most recent open models → RoPE (rotary position embeddings)
- ALiBi → attention bias instead of embeddings
But the conceptual role remains unchanged:
Inject order into an otherwise permutation-invariant attention mechanism.
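For reference, here is a minimal sketch of RoPE in the rotate-half convention used by GPT-NeoX and LLaMA. It is applied to the query and key tensors inside attention (not added to the token embeddings) and is not used elsewhere in this notebook; the function name apply_rope is my own:
def apply_rope(x, base=10000.0):
    # x: (batch, n_heads, seq_len, d_k) with d_k even.
    # Each feature pair (i, i + d_k/2) is rotated by a position-dependent angle,
    # so dot products between rotated queries and keys depend on relative position.
    *_, T, D = x.shape
    half = D // 2
    inv_freq = base ** (-torch.arange(0, half, dtype=torch.float32, device=x.device) / half)
    angles = torch.arange(T, dtype=torch.float32, device=x.device)[:, None] * inv_freq  # (T, half)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)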
class TransformerBlock(nn.Module):
def __init__(self, d_model, n_heads, d_ff):
super().__init__()
# Masked multi-head self-attention
# Uses causal masking to prevent access to future tokens
self.attn = MultiHeadAttention(d_model, n_heads)
# Position-wise feedforward network
self.ff = FeedForward(d_model, d_ff)
# Pre-layer normalization
# Improves gradient flow in deep networks
self.ln1 = nn.LayerNorm(d_model)
self.ln2 = nn.LayerNorm(d_model)
def forward(self, x, mask):
# Pre-LN + residual connection for attention
# x ← x + Attention(LayerNorm(x))
x = x + self.attn(self.ln1(x), mask)
# Pre-LN + residual connection for feedforward
# x ← x + FFN(LayerNorm(x))
x = x + self.ff(self.ln2(x))
return x
class GPT(nn.Module):
def __init__(self, vocab_size):
super().__init__()
# Token embedding layer
# Maps token IDs to continuous vectors
self.embed = nn.Embedding(vocab_size, D_MODEL)
# Positional encoding to inject order information
self.pos = PositionalEncoding(D_MODEL, BLOCK_SIZE)
# Stack of Transformer blocks
# Each block refines representations autoregressively
self.blocks = nn.ModuleList([
TransformerBlock(D_MODEL, N_HEADS, D_FF)
for _ in range(N_LAYERS)
])
# Final LayerNorm before output projection
self.ln_f = nn.LayerNorm(D_MODEL)
# Output projection to vocabulary space
self.head = nn.Linear(D_MODEL, vocab_size)
# Weight tying between input embeddings and output logits
# Reduces parameters and improves generalization
self.head.weight = self.embed.weight
def forward(self, x):
"""
Args:
x: Input token IDs (batch_size, seq_len)
Returns:
Logits over vocabulary for each position
"""
B, T = x.shape
# Create causal mask to enforce autoregressive constraint
mask = causal_mask(T).to(x.device)
# Token embeddings + positional encodings
x = self.embed(x)
x = self.pos(x)
# Apply stacked Transformer blocks
for block in self.blocks:
x = block(x, mask)
# Final normalization and projection to logits
x = self.ln_f(x)
return self.head(x)
GPT removes the encoder, enforces causality everywhere, stabilizes training with pre-norm, and scales depth—turning the Transformer into a pure next-token prediction machine.
Training¶
model = GPT(tokenizer.vocab_size).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()
import torch, gc
gc.collect()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
print(torch.cuda.memory_summary())
[torch.cuda.memory_summary() output — key figures: Allocated memory 13,858 KiB (current and peak), GPU reserved memory 14,336 KiB, CUDA OOMs 0, cudaMalloc retries 0, Allocations 53, all from the small pool]
import torch
from torch.amp import autocast, GradScaler  # torch.amp replaces the deprecated torch.cuda.amp API

# Compile model (PyTorch 2.x) for faster training via kernel fusion
model = torch.compile(model)

# GradScaler handles loss scaling for mixed-precision (fp16) training
scaler = GradScaler("cuda")

ACC_STEPS = 4    # 🔧 tune this: number of micro-batches per optimizer step
CLIP_NORM = 1.0  # max global gradient norm before each optimizer step
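Two quick notes on the knobs above. ACC_STEPS controls gradient accumulation: the optimizer only steps once every ACC_STEPS micro-batches, so the effective batch size is ACC_STEPS times the DataLoader batch size, and dividing the loss by ACC_STEPS keeps the accumulated gradient equivalent to one large-batch gradient. CLIP_NORM then caps that gradient's global norm, which guards against occasional exploding updates. The snippet below is not from the original notebook; it is a tiny standalone sanity check of the accumulation equivalence with made-up tensors and a mean-reduced loss (the same reduction nn.CrossEntropyLoss uses by default).
# Sanity check: accumulated micro-batch gradients == one large-batch gradient
import torch

torch.manual_seed(0)
w = torch.randn(3, requires_grad=True)
data, target = torch.randn(8, 3), torch.randn(8)

# Gradient from one "large" batch of 8
big_loss = ((data @ w - target) ** 2).mean()
grad_big = torch.autograd.grad(big_loss, w)[0]

# Gradient accumulated over 2 micro-batches of 4, each loss divided by 2
w.grad = None
for xb, yb in zip(data.chunk(2), target.chunk(2)):
    micro_loss = ((xb @ w - yb) ** 2).mean() / 2
    micro_loss.backward()

print(torch.allclose(w.grad, grad_big))  # True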
@torch.no_grad()
def generate(prompt, max_new_tokens=300, temperature=0.8):
    model.eval()
    ids = torch.tensor([tokenizer.encode(prompt)], device=DEVICE)
    for _ in range(max_new_tokens):
        # Truncate the context to the last BLOCK_SIZE tokens:
        # the model was trained with a BLOCK_SIZE context window.
        input_ids = ids[:, -BLOCK_SIZE:]
        logits = model(input_ids)
        # Keep only the logits for the final position and apply temperature scaling
        logits = logits[:, -1, :] / temperature
        probs = F.softmax(logits, dim=-1)
        # Sample the next token and append it to the running sequence
        next_id = torch.multinomial(probs, 1)
        ids = torch.cat([ids, next_id], dim=1)
    return tokenizer.decode(ids[0].tolist())
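One aside on the temperature parameter: dividing the logits by temperature before the softmax controls how peaked the sampling distribution is, which is why 0.8 gives slightly more conservative text than 1.0. The snippet below is standalone and not from the original notebook; the logits are made up purely to show the effect.
# How temperature reshapes a fixed set of logits before sampling
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.2, -1.0])
for temperature in (0.5, 0.8, 1.0, 2.0):
    probs = F.softmax(logits / temperature, dim=-1)
    print(f"T={temperature:3.1f} -> {[round(p, 3) for p in probs.tolist()]}")

# Low T concentrates probability on the top token (near-greedy sampling);
# high T flattens the distribution, so rarer tokens get sampled more often.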
from tqdm import tqdm
import torch
import math
import os
from torch.amp import autocast, GradScaler  # torch.amp replaces the deprecated torch.cuda.amp API

# ============================================================
# RESUME FROM CHECKPOINT (FULL STATE)
# ============================================================
best_val_loss = math.inf
start_epoch = 0
scaler = GradScaler("cuda")

if os.path.exists("best_model.pt"):
    print("Resuming from best_model.pt")
    ckpt = torch.load("best_model.pt", map_location=DEVICE)

    if isinstance(ckpt, dict) and "model" in ckpt:
        # ✅ New-style checkpoint: full training state
        model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        scaler.load_state_dict(ckpt["scaler"])
        best_val_loss = ckpt["best_val_loss"]
        start_epoch = ckpt["epoch"] + 1
    else:
        # ✅ Old-style checkpoint: weights only
        model.load_state_dict(ckpt)
        print("Loaded weights-only checkpoint (optimizer/scaler reset)")

model.to(DEVICE)
# ============================================================
# TRAINING LOOP
# ============================================================
for epoch in range(start_epoch, EPOCHS):

    # ===================== TRAIN =====================
    model.train()
    optimizer.zero_grad(set_to_none=True)
    train_loss = 0.0

    for step, (x, y) in enumerate(tqdm(train_loader, desc=f"Train {epoch+1}")):
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)

        with autocast("cuda"):
            logits = model(x)
            loss = criterion(
                logits.view(-1, logits.size(-1)),
                y.view(-1)
            )
            loss = loss / ACC_STEPS

        scaler.scale(loss).backward()
        train_loss += loss.item() * ACC_STEPS

        should_step = (
            (step + 1) % ACC_STEPS == 0
            or (step + 1) == len(train_loader)  # flush the last partial accumulation
        )
        if should_step:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_NORM)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

    train_loss /= len(train_loader)

    # ===================== VALIDATE =====================
    model.eval()
    val_loss = 0.0
    MAX_VAL_BATCHES = 200

    with torch.no_grad(), autocast("cuda"):
        for i, (x, y) in enumerate(val_loader):
            if i >= MAX_VAL_BATCHES:
                break
            x = x.to(DEVICE, non_blocking=True)
            y = y.to(DEVICE, non_blocking=True)
            logits = model(x)
            loss = criterion(
                logits.view(-1, logits.size(-1)),
                y.view(-1)
            )
            val_loss += loss.item()

    val_loss /= min(len(val_loader), MAX_VAL_BATCHES)

    # ===================== LOGGING + CHECKPOINT =====================
    improved = val_loss < best_val_loss
    if improved:
        best_val_loss = val_loss
        torch.save(
            {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scaler": scaler.state_dict(),
                "epoch": epoch,
                "best_val_loss": best_val_loss,
            },
            "best_model.pt",
        )

    print(
        f"Epoch {epoch+1:3d} | "
        f"Train {train_loss:.4f} | "
        f"Val {val_loss:.4f}"
        + (" ✓" if improved else "")
    )

    # ===================== SAMPLE =====================
    if (epoch + 1) % 10 == 0:
        # The GPT module itself has no .generate() method,
        # so use the sampling helper defined above.
        print(f"\nSample:\n{generate('O', max_new_tokens=100)}\n")
# ============================================================
# DONE
# ============================================================
print("\n" + "=" * 50)
print(f"Training complete! Best val loss: {best_val_loss:.4f}")
print("=" * 50)
Resuming from best_model.pt
Loaded weights-only checkpoint (optimizer/scaler reset)
Train 1: 100%|██████████| 1539/1539 [04:39<00:00, 5.51it/s]
Epoch 1 | Train 2.9840 | Val 2.8299 ✓
Train 2: 100%|██████████| 1539/1539 [04:47<00:00, 5.36it/s]
Epoch 2 | Train 2.0371 | Val 2.9852
Train 3: 100%|██████████| 1539/1539 [04:46<00:00, 5.37it/s]
Epoch 3 | Train 1.2778 | Val 4.0615
Train 4: 100%|██████████| 1539/1539 [04:46<00:00, 5.38it/s]
Epoch 4 | Train 0.7066 | Val 5.6822
Train 5: 100%|██████████| 1539/1539 [04:45<00:00, 5.40it/s]
Epoch 5 | Train 0.3048 | Val 7.1991
Train 6: 5%|▍ | 74/1539 [00:13<04:36, 5.31it/s]
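A note on the log above: training loss falls steadily (2.98 → 0.30) while validation loss climbs after epoch 1 (2.83 → 7.20). That is textbook overfitting on a small dataset, and it is why only the epoch-1 weights ever get written to best_model.pt. One common response is early stopping. The snippet below is not part of the training loop; it simply replays the logged validation losses through a patience rule (patience is a hypothetical knob here) to show when such a rule would have halted the run.
# Replay the validation losses from the log through a simple patience rule
val_losses = [2.8299, 2.9852, 4.0615, 5.6822, 7.1991]  # epochs 1-5 above
patience = 2  # stop after this many epochs without improvement

best, bad_epochs = float("inf"), 0
for epoch, v in enumerate(val_losses, start=1):
    if v < best:
        best, bad_epochs = v, 0
    else:
        bad_epochs += 1
    if bad_epochs >= patience:
        print(f"Would stop after epoch {epoch} (best val loss {best:.4f})")
        break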
Generation¶
Drumroll ....¶
print("\n--- GENERATED STORY ---\n")
print(generate("Once upon a time "))
--- GENERATED STORY ---

Once upon a time was tirered. It made go to the park, but he also lion named sed sed with a new friend, a little girl named Lily him. She saw her frog and opped her head. Lily was very munice. Then, She camemmembered Mom comineace braceful was not too too happy ve fun. <|endoftext|>

Once upon a time, there was a little girl named Mia. Mia had a big, red in a little house with her always smiled. Mia truckly little cat the ad. Lily lived in a small house with her momy. Mia loved to stay very happy. Mia went outside to play in the stawn and saw a lotched the slide. She was very happy and s. They cyedguit. The slizzlep. Mia could not find it was very happy. The kids girl
A 3.1M-parameter model is very small by modern language-model standards. For context, classic “toy” GPT implementations used for learning typically range from 1–10M parameters, which is enough to capture basic token statistics and short-range patterns but not sustained coherence. GPT-2 Small already has 117M parameters (≈40× larger), GPT-2 Medium 345M, GPT-2 Large 774M, and GPT-2 XL 1.5B. Modern production LLMs operate in the tens to hundreds of billions of parameters. Practically, this means a 3.1M-parameter model is expected to produce locally sensible text but to struggle with long-term consistency, reasoning, and abstraction.
This is ideal for understanding how transformers work, not for demonstrating strong language capability.
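If you want to verify the parameter count for your own configuration, it is a one-liner; this assumes the model built in the Training section is still in scope (it also works on the torch.compile-wrapped module, which exposes the same parameters).
# Count trainable parameters of the GPT defined above
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{n_params / 1e6:.1f}M parameters")  # ≈ 3.1M for the configuration used here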