Bridging GRPO and Transformer Learning Mechanisms to Enhance Language Model Training

January 22, 2025


In this post, we will explore how GRPO (Group Relative Policy Optimization) naturally extends the standard Transformer training paradigm (which is typically pure next-token prediction using maximum likelihood) into a reinforcement learning (RL) framework. We will derive the connections step-by-step and discuss how these insights can guide future improvements.


Review

1. Why a Naive Transformer Can Learn

We’ll review how the core building blocks (causal masking, self-attention, linear layers, and layer normalization) combine so that the Transformer learns

$$p_\theta(x_t \mid x_{<t}),$$

and thus factorizes

$$p_\theta(\mathbf{x}) \;=\; \prod_{t=1}^N p_\theta(x_t \mid x_{<t})$$

for a text sequence $\mathbf{x}=(x_1,\ldots,x_N)$. Finally, we connect it to the negative log-likelihood (NLL) objective.


1.1 Causal Masking

In a causal language model, position $t$ may attend only to positions $1,\dots,t$ (itself and earlier tokens), never to future positions. To enforce this in self-attention, we define a mask $\mathbf{M}$ such that

$$\mathbf{M}_{t,s} \;=\; \begin{cases} 0, & s \le t,\\ -\infty, & s > t. \end{cases}$$

For each self-attention head, the attention weights $\boldsymbol{\alpha}_{t,:}$ are given by a softmax over the scaled dot products plus the mask:

$$\alpha_{t,s} \;=\; \mathrm{softmax}\!\Bigl( \frac{\mathbf{Q}_t \cdot \mathbf{K}_s}{\sqrt{d_k}} \;+\; \mathbf{M}_{t,s} \Bigr),$$

where $\alpha_{t,s}$ is the attention weight of position $t$ on position $s$. Because $\mathbf{M}_{t,s} = -\infty$ for $s>t$, those positions contribute zero weight after the softmax, enforcing causality.
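
To make the mask concrete, here is a minimal sketch in PyTorch (the library choice and tensor names are mine, not the post's) that builds $\mathbf{M}$ for a short sequence and applies it to raw attention scores:

```python
import torch

def causal_mask(n: int) -> torch.Tensor:
    """Return an (n, n) mask with 0 where s <= t and -inf where s > t."""
    # torch.triu with diagonal=1 selects strictly-future positions (s > t).
    future = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)
    return torch.zeros(n, n).masked_fill(future, float("-inf"))

# Example: raw (unnormalized) attention scores for a length-4 sequence.
scores = torch.randn(4, 4)                       # stands in for Q_t . K_s / sqrt(d_k)
weights = torch.softmax(scores + causal_mask(4), dim=-1)

# Each row t now sums to 1 and places zero weight on columns s > t.
print(weights)
```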


1.2 Query, Key, Value Transformations

Each multi-head self-attention block creates queries, keys, and values via learned linear projections of the hidden states $\mathbf{z}_s$. For head $h$ (out of $H$ total heads), we define:

$$\mathbf{Q}_s^{(h)} \;=\; \mathbf{W}_Q^{(h)} \,\mathbf{z}_s, \quad \mathbf{K}_s^{(h)} \;=\; \mathbf{W}_K^{(h)} \,\mathbf{z}_s, \quad \mathbf{V}_s^{(h)} \;=\; \mathbf{W}_V^{(h)} \,\mathbf{z}_s.$$

  • $\mathbf{W}_Q^{(h)},\mathbf{W}_K^{(h)},\mathbf{W}_V^{(h)}$ each have dimension $d_k\times d_\text{model}$ (or $d_v\times d_\text{model}$ for the value projection), where typically $d_k = d_v = d_\text{model}/H$.

For position $t$, the attention weights w.r.t. positions $s\le t$ become:

$$\alpha_{t,s}^{(h)} \;=\; \mathrm{softmax}\!\Bigl( \frac{\mathbf{Q}_t^{(h)} \cdot \mathbf{K}_s^{(h)}}{\sqrt{d_k}} + \mathbf{M}_{t,s} \Bigr).$$

The context vector for head hh at position tt is then

$$\mathbf{c}_t^{(h)} \;=\; \sum_{s=1}^{N} \alpha_{t,s}^{(h)} \;\mathbf{V}_s^{(h)},$$

noting that $\alpha_{t,s}^{(h)}=0$ for $s>t$ via the causal mask.

1.2.1 Multi-Head Concatenation and Output Projection

If there are HH heads, we concatenate their outputs:

$$\mathbf{c}_t \;=\; \mathrm{Concat}\!\bigl[ \mathbf{c}_t^{(1)},\;\mathbf{c}_t^{(2)},\;\dots,\;\mathbf{c}_t^{(H)} \bigr],$$

which has dimension $d_\text{model} = H\,d_k$. We then apply another linear projection

$$\mathbf{u}_t \;=\; \mathbf{W}_O\,\mathbf{c}_t \;+\; \mathbf{b}_O,$$

where $\mathbf{W}_O$ has shape $d_\text{model}\times d_\text{model}$. The residual connection is typically added next (plus layer normalization, discussed below).
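
Putting Sections 1.2 and 1.2.1 together, a minimal multi-head causal self-attention module might look like the following sketch (PyTorch; class and variable names are illustrative, and the four projections are stacked across all heads, as is common in practice):

```python
import math
import torch
import torch.nn as nn

class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with a causal mask, as described above."""

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_k = d_model // n_heads            # d_k = d_v = d_model / H
        self.n_heads = n_heads
        self.w_q = nn.Linear(d_model, d_model)   # W_Q (all heads stacked)
        self.w_k = nn.Linear(d_model, d_model)   # W_K
        self.w_v = nn.Linear(d_model, d_model)   # W_V
        self.w_o = nn.Linear(d_model, d_model)   # W_O, final output projection

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        # z: (batch, seq_len, d_model) hidden states
        b, n, _ = z.shape
        # Project and reshape to (batch, heads, seq_len, d_k).
        q = self.w_q(z).view(b, n, self.n_heads, self.d_k).transpose(1, 2)
        k = self.w_k(z).view(b, n, self.n_heads, self.d_k).transpose(1, 2)
        v = self.w_v(z).view(b, n, self.n_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product scores plus the causal mask M.
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)   # (b, H, n, n)
        future = torch.triu(torch.ones(n, n, dtype=torch.bool, device=z.device), 1)
        scores = scores.masked_fill(future, float("-inf"))

        alpha = torch.softmax(scores, dim=-1)     # attention weights alpha_{t,s}
        c = alpha @ v                             # per-head context vectors c_t^(h)

        # Concatenate heads and apply W_O.
        c = c.transpose(1, 2).contiguous().view(b, n, self.n_heads * self.d_k)
        return self.w_o(c)                        # u_t for every position t
```

A residual connection and LayerNorm (Section 1.4) would wrap this block in a full Transformer layer.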


1.3 Linear Transformations

A linear layer in a Transformer can be written:

$$\mathrm{Linear}(\mathbf{x}) \;=\; \mathbf{W}\,\mathbf{x} \;+\; \mathbf{b},$$

where $\mathbf{W}\in\mathbb{R}^{d_\text{out}\times d_\text{in}}$ and $\mathbf{b}\in \mathbb{R}^{d_\text{out}}$. In the multi-head attention context, we have three such linear layers ($\mathbf{W}_Q,\mathbf{W}_K,\mathbf{W}_V$) to form queries, keys, and values, plus a final one ($\mathbf{W}_O$) to map the concatenated heads back to $d_\text{model}$.

Additionally, a typical feed-forward sub-layer within the Transformer is also a composition of two linear layers with a nonlinearity $\sigma(\cdot)$ (often ReLU or GELU):

$$\mathrm{FFN}(\mathbf{x}) \;=\; \mathbf{W}_2 \,\sigma\!\bigl(\mathbf{W}_1\,\mathbf{x} + \mathbf{b}_1 \bigr) \;+\; \mathbf{b}_2,$$

where $\mathbf{W}_1\in \mathbb{R}^{d_\text{ff}\times d_\text{model}}$ and $\mathbf{W}_2\in \mathbb{R}^{d_\text{model}\times d_\text{ff}}$.
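
As a brief sketch (again PyTorch, with illustrative names), the position-wise feed-forward sub-layer is simply:

```python
import torch.nn as nn

class FeedForward(nn.Module):
    """Position-wise FFN: W_2 * sigma(W_1 x + b_1) + b_2."""

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),   # W_1, b_1
            nn.GELU(),                  # sigma (GELU here; ReLU is also common)
            nn.Linear(d_ff, d_model),   # W_2, b_2
        )

    def forward(self, x):
        return self.net(x)
```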


1.4 Layer Normalization

LayerNorm is applied either before or after the main sub-layer operations (depending on the Transformer variant). The standard formula for a hidden vector $\mathbf{z} = (z_1,\dots,z_d)\in\mathbb{R}^d$ is:

$$\mathrm{LN}(\mathbf{z}) \;=\; \frac{\mathbf{z}\;-\;\mu(\mathbf{z})} {\sqrt{\sigma^2(\mathbf{z}) + \varepsilon}} \;\odot\;\boldsymbol{\gamma} \;+\; \boldsymbol{\beta},$$

where

  1. $\mu(\mathbf{z})$ is the mean of the components $(z_1,\dots,z_d)$,
  2. $\sigma^2(\mathbf{z})$ is the variance of $(z_1,\dots,z_d)$,
  3. $\boldsymbol{\gamma}$ and $\boldsymbol{\beta}$ are learnable parameters of dimension $d$,
  4. $\varepsilon$ is a small constant (e.g., $10^{-5}$) to avoid division by zero,
  5. $\odot$ denotes elementwise multiplication.

Thus, LayerNorm re-centers and re-scales $\mathbf{z}$ using statistics computed across its $d$ components, independently for each position, which improves training stability.
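
A from-scratch version of this formula, equivalent in spirit to PyTorch's built-in nn.LayerNorm (names here are illustrative), is short:

```python
import torch
import torch.nn as nn

class LayerNorm(nn.Module):
    """LN(z) = (z - mu) / sqrt(var + eps) * gamma + beta, per position."""

    def __init__(self, d: int, eps: float = 1e-5):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d))
        self.beta = nn.Parameter(torch.zeros(d))
        self.eps = eps

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        mu = z.mean(dim=-1, keepdim=True)                  # mean over the d components
        var = z.var(dim=-1, unbiased=False, keepdim=True)  # variance over the d components
        return (z - mu) / torch.sqrt(var + self.eps) * self.gamma + self.beta
```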


1.5 Auto-Regressive Next-Token Probability

After $L$ layers of (1) self-attention with causal masking, (2) feed-forward layers, and (3) layer normalization plus residual connections, we obtain a final hidden state $\mathbf{h}_t$ for each position $t$. We then project $\mathbf{h}_t$ to vocabulary logits:

$$\mathbf{o}_t \;=\; \mathbf{W}_\text{vocab}\,\mathbf{h}_t \;+\; \mathbf{b}_\text{vocab},$$

where $\mathbf{W}_\text{vocab}$ has size $\lvert\mathcal{V}\rvert\times d_\text{model}$ (assuming a vocabulary $\mathcal{V}$ of size $\lvert \mathcal{V}\rvert$) and $\mathbf{b}_\text{vocab}\in\mathbb{R}^{\lvert \mathcal{V}\rvert}$. A softmax over these logits yields:

$$p_\theta(x_t=w \mid x_{<t}) \;=\; \frac{\exp(\mathbf{o}_t[w])} {\sum_{u\in \mathcal{V}} \exp(\mathbf{o}_t[u])}.$$

This defines the auto-regressive distribution for the next token.
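
A minimal sketch of this projection-and-softmax step (PyTorch; the sizes and names are purely illustrative):

```python
import torch
import torch.nn as nn

d_model, vocab_size = 512, 32000               # illustrative sizes
to_logits = nn.Linear(d_model, vocab_size)     # W_vocab and b_vocab

h_t = torch.randn(1, d_model)                  # final hidden state at position t
o_t = to_logits(h_t)                           # logits o_t over the vocabulary
p_next = torch.softmax(o_t, dim=-1)            # p_theta(x_t = w | x_{<t})
next_token = torch.argmax(p_next, dim=-1)      # e.g., greedy decoding picks the argmax
```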


1.6 Negative Log-Likelihood Objective

With the entire network parameterized by $\theta$, the language modeling likelihood of an observed sequence $\mathbf{x} = (x_1,\dots,x_N)$ factorizes as:

$$p_\theta(\mathbf{x}) \;=\; \prod_{t=1}^{N} p_\theta\bigl(x_t \mid x_{<t}\bigr).$$

The training objective is to minimize the negative log-likelihood over a large corpus $\mathcal{D}$:

$$\mathcal{L}_\text{NLL}(\theta) \;=\; -\,\sum_{\mathbf{x}\in \mathcal{D}} \;\sum_{t=1}^N \log\,p_\theta(x_t \mid x_{<t}),$$

which is equivalent to the cross-entropy loss. Concretely, for each token $x_t$, the cross-entropy w.r.t. the model’s predicted distribution is

$$-\log\,p_\theta(x_t \mid x_{<t}) \;=\; -\mathbf{o}_t[x_t] \;+\; \log\!\Bigl( \sum_{u\in \mathcal{V}} \exp(\mathbf{o}_t[u]) \Bigr),$$

and backpropagation through the Q, K, V transformations, the linear layers, and the layer normalization updates all parameters $\theta$ so as to increase the likelihood of the observed tokens.
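
The per-token expression above is exactly what a standard cross-entropy routine computes from the logits; here is a minimal sketch of the NLL loss for one batch (PyTorch; nll_loss and the shift-by-one convention are illustrative, assuming a model that returns per-position logits):

```python
import torch
import torch.nn.functional as F

# logits: (batch, seq_len, |V|) produced by the Transformer at each position t,
# tokens: (batch, seq_len) the observed sequence x_1, ..., x_N.
def nll_loss(logits: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    # Predict token t+1 from the hidden state at position t (shift by one).
    pred = logits[:, :-1, :].reshape(-1, logits.size(-1))
    target = tokens[:, 1:].reshape(-1)
    # cross_entropy = -o_t[x_t] + log(sum_u exp(o_t[u])), averaged over tokens.
    return F.cross_entropy(pred, target)

# loss = nll_loss(model(tokens), tokens); loss.backward(); optimizer.step()
```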


1.7 Final Takeaway

  • Causal Masking:

    $$\mathbf{M}_{t,s} = \begin{cases} 0, & s\le t\\ -\infty, & \text{otherwise} \end{cases}$$

    ensures that position $t$ never sees positions $s>t$.

  • Self-Attention: Uses

    $$\alpha_{t,s} = \mathrm{softmax} \Bigl( \frac{\mathbf{Q}_t \cdot \mathbf{K}_s}{\sqrt{d_k}} + \mathbf{M}_{t,s} \Bigr), \quad \mathbf{c}_t = \sum_{s=1}^{N} \alpha_{t,s}\,\mathbf{V}_s,$$

    to selectively integrate information from previous tokens.

  • Linear Layers:

    $$\mathrm{Linear}(\mathbf{x}) = \mathbf{W}\mathbf{x} + \mathbf{b},$$

    transform the embeddings, the attention outputs, and the feed-forward activations.

  • LayerNorm:

    $$\mathrm{LN}(\mathbf{z}) = \frac{\mathbf{z} - \mu(\mathbf{z})} {\sqrt{\sigma^2(\mathbf{z}) + \varepsilon}} \odot\boldsymbol{\gamma} + \boldsymbol{\beta}.$$

  • Vocabulary Projection: A final linear layer $\mathbf{W}_\text{vocab}$ maps the final hidden state $\mathbf{h}_t$ to vocabulary logits $\mathbf{o}_t$, to which we apply a softmax.

  • Training: Minimizing

    $$\mathcal{L}_\text{NLL}(\theta) \;=\; -\sum_{\mathbf{x}\in \mathcal{D}}\sum_{t=1}^N \log\,p_\theta(x_t \mid x_{<t})$$

    trains the model parameters $\theta$ to make accurate next-token predictions, purely via supervised learning.

  • Key: Minimizing this Negative Log-Likelihood (NLL) via cross-entropy is equivalent to maximizing the likelihood of the training data. This procedure uses purely supervised learning with no environment-based “reward.”


Reinforcement Learning in DeepSeek R1

2. Enter GRPO: An RL Overlay on Transformers

TL;DR: What’s the difference?

When we switch from pure next-token prediction to a reinforcement learning perspective, we view each generated token (or entire sequence) as an “action” that yields a reward. The policy $\pi_\theta$ is still a Transformer, but its training signal changes from simple log-likelihood to expected reward.

2.1. Group Relative Policy Optimization in a Nutshell

  • Group Baseline: Instead of learning a value function as in PPO, GRPO uses a group-level baseline for variance reduction. For a group of $k$ outputs $\{ a^1, a^2, \dots, a^k \}$ sampled from the same prompt $\mathcal{X}$, it computes the group-average reward

    $$b \;=\; \frac{1}{k}\,\sum_{i=1}^{k} r^i,$$

    where $r^i$ is the reward of the $i$-th output.

  • Group-Relative Advantage:

    $$A^i \;=\; r^i \;-\; b.$$

    This advantage replaces the role of a learned value function $V(\cdot)$ in PPO (see the sketch after this list).

  • Policy Ratio: As with PPO, GRPO measures the ratio of new policy to old policy

    $$r_i(\theta) \;=\; \frac{\pi_\theta(a^i \mid \mathcal{X})}{\pi_{\theta_\mathrm{old}}(a^i \mid \mathcal{X})}.$$
  • KL Divergence or Clipping: GRPO can control policy updates via a KL penalty or by clipping this ratio in a trust region, similar to PPO.
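
To make the group statistics concrete, here is a minimal sketch of the baseline, the group-relative advantages, and the policy ratios for one prompt (PyTorch; the reward values are made up, and the log-probabilities are assumed to come from scoring the $k$ sampled outputs under the new and old policies):

```python
import torch

def group_relative_advantages(rewards: torch.Tensor) -> torch.Tensor:
    """A^i = r^i - b, where b is the group-average reward for one prompt."""
    # Some GRPO variants also divide by rewards.std(); the formulation in this
    # post uses only the mean baseline, so we do the same here.
    baseline = rewards.mean()          # b = (1/k) * sum_i r^i
    return rewards - baseline

def policy_ratios(logp_new: torch.Tensor, logp_old: torch.Tensor) -> torch.Tensor:
    """r_i(theta) = pi_theta(a^i | X) / pi_theta_old(a^i | X), via log-probs."""
    return torch.exp(logp_new - logp_old)

# Example with k = 4 sampled outputs for a single prompt.
rewards = torch.tensor([1.0, 0.0, 0.5, 1.0])
advantages = group_relative_advantages(rewards)
# -> tensor([ 0.3750, -0.6250, -0.1250,  0.3750])
```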

2.2. Surrogate Objective with Group Averages

If we adopt a clipping-based approach (akin to PPO), the GRPO objective might look like:

$$L^\mathrm{GRPO}(\theta) \;=\; \mathbb{E}_{\mathcal{X},\,a}\! \Bigl[ \min\bigl(r_i(\theta)\,A^i,\; \mathrm{clip}\bigl(r_i(\theta),\,1-\epsilon,\,1+\epsilon\bigr)\,A^i \bigr) \Bigr].$$

Alternatively, if we incorporate a direct KL penalty w.r.t. a reference policy πref\pi_{\mathrm{ref}}, we might write:

$$L^\mathrm{GRPO}(\theta) \;=\; \mathbb{E}_{\mathcal{X},\,a}\! \Bigl[ r_i(\theta)\,A^i \;-\; \beta\,D_\mathrm{KL}\!\bigl(\pi_\theta(\cdot\mid\mathcal{X})\;\big\|\;\pi_{\mathrm{ref}}(\cdot\mid\mathcal{X})\bigr) \Bigr].$$

In both cases, the key difference from pure cross-entropy training is that we directly maximize a reward-based objective rather than the likelihood of tokens in a static dataset.
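
Here is a sketch of both variants, written as losses to minimize (PyTorch; the function names are mine, and the default $\epsilon$ and $\beta$ are illustrative rather than values from any specific paper):

```python
import torch

def grpo_clip_loss(ratios: torch.Tensor, advantages: torch.Tensor,
                   eps: float = 0.2) -> torch.Tensor:
    """Clipped surrogate: maximize min(r_i * A^i, clip(r_i, 1-eps, 1+eps) * A^i)."""
    unclipped = ratios * advantages
    clipped = torch.clamp(ratios, 1.0 - eps, 1.0 + eps) * advantages
    # Negate because optimizers minimize, while the surrogate is maximized.
    return -torch.min(unclipped, clipped).mean()

def grpo_kl_loss(ratios: torch.Tensor, advantages: torch.Tensor,
                 logp_new: torch.Tensor, logp_ref: torch.Tensor,
                 beta: float = 0.04) -> torch.Tensor:
    """KL-penalized surrogate: maximize r_i * A^i - beta * KL(pi_theta || pi_ref)."""
    # A simple Monte-Carlo estimate of KL(pi_theta || pi_ref) from log-probs of
    # samples drawn from pi_theta; other estimators exist.
    kl = (logp_new - logp_ref).mean()
    return -(ratios * advantages).mean() + beta * kl
```

Either loss can be dropped into a standard training loop in place of the cross-entropy loss from Section 1.6.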


3. Where They Intersect: The Transformer as a Policy Network

Underneath both pure next-token MLE and GRPO (or PPO), the neural architecture is typically the same Transformer. What changes is the loss function and associated training data:

  1. Pure MLE: The training data is a static corpus of text, and the objective is to predict the next token—no explicit reward.

  2. GRPO / PPO:

    • We still use the Transformer architecture for $\pi_\theta(a \mid \mathcal{X})$.
    • We gather data (policy outputs) and compute or receive a reward signal $r^i$.
    • We update $\theta$ to improve this reward-based objective.

Thus, the mechanism of self-attention, feed-forward layers, and so on does not change. Instead, the training objective shifts from “match the next token distribution in the dataset” to “output sequences with high reward.”


4. Potential Improvements: Unifying Likelihood and Reward Signals

A core tension arises in practical LLM training: we want linguistic fluency (which is well-captured by next-token prediction on large corpora) but also want task-specific behaviors (captured by a reward model or user feedback). Some ways forward:

  1. Two-Stage Training:

    • First, pre-train the Transformer with standard cross-entropy on a large text corpus (for broad language fluency).
    • Then, fine-tune with GRPO on a smaller dataset or a reward signal. This approach is used in many RLHF (Reinforcement Learning from Human Feedback) pipelines.
  2. Hybrid Objective:
    Combine the likelihood term and the reward term in a single objective—e.g.:

    $$L(\theta) \;=\; \lambda_\mathrm{MLE}\, \Bigl[ -\,\log p_\theta(x) \Bigr] \;+\; \lambda_\mathrm{RL}\, \Bigl[ -\,r_\mathrm{KL}\,D_\mathrm{KL}\bigl(\pi_\theta\|\pi_{\mathrm{ref}}\bigr) \;+\; \dots \Bigr].$$

    Tuning $\lambda_\mathrm{MLE}$ vs. $\lambda_\mathrm{RL}$ balances linguistic correctness with reward maximization (a sketch of this combination appears after this list).

  3. Better Baselines:

    • Instead of using a simple group average or a learned value function, one could integrate more sophisticated baselines (e.g., learned critics that exploit contextual cues).
    • Or incorporate group-level variance reduction plus longer-horizon estimates (GAE-like expansions) for tasks requiring multiple steps.
  4. Dynamic Reference Policies:

    • Periodically update $\pi_{\mathrm{ref}}$ to be the current $\pi_\theta$.
    • Use an adaptive schedule for the KL penalty coefficient so that the policy can explore initially but is later constrained when it’s sufficiently good.
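
As a sketch of the hybrid objective in option 2 above (the weights and helper functions are hypothetical, reusing nll_loss from Section 1.6 and grpo_clip_loss from Section 2.2):

```python
import torch

def hybrid_loss(nll: torch.Tensor, grpo: torch.Tensor,
                lam_mle: float = 0.5, lam_rl: float = 0.5) -> torch.Tensor:
    """L(theta) = lambda_MLE * (NLL term) + lambda_RL * (RL surrogate loss)."""
    return lam_mle * nll + lam_rl * grpo

# Usage, tying together the earlier sketches:
#   nll  = nll_loss(model(tokens), tokens)        # Section 1.6
#   rl   = grpo_clip_loss(ratios, advantages)     # Section 2.2
#   loss = hybrid_loss(nll, rl); loss.backward()
```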

5. Mathematical Summary: The Relation Between MLE and GRPO

We can compare the respective training objectives more formally:

  1. Standard Transformer (MLE)

    $$\min_\theta \;\; \Bigl[ -\sum_{\mathbf{x} \in \mathcal{D}} \log p_\theta(\mathbf{x}) \Bigr] \;=\; \min_\theta \;\; \mathcal{L}_\text{NLL}(\theta).$$

    • Data $\mathbf{x}$ is from a static corpus.
    • The “policy” $\pi_\theta$ is trained to predict tokens accurately.
  2. GRPO Objective

    $$\max_\theta \;\; \mathbb{E}_{\mathcal{X},\, a^i\sim\pi_\theta} \Bigl[ r_i(\theta)\,A^i \;-\; \beta\,D_{\mathrm{KL}}\bigl(\pi_\theta \,\|\, \pi_\mathrm{ref}\bigr) \Bigr],$$

    • Data is generated on the fly by sampling from $\pi_\theta$.
    • Rewards $\{r^i\}$ can come from a reward model or human feedback.
    • The group-based baseline $b$ reduces variance.

Bridging the Two Objectives

In many real-world applications (e.g., RLHF), the pipeline is:

  1. Pre-train on MLE (huge corpus).
  2. Train a reward model (from human preferences or other signals).
  3. Fine-tune with an RL method like GRPO or PPO to align the model with desired behavior.

This approach leverages both the linguistic understanding from MLE and the targeted reward optimization from RL.


6. Concluding Thoughts

  • GRPO can be viewed as a natural RL extension of a Transformer that otherwise does purely next-token prediction.
  • The Transformer architecture remains the same; what changes is the objective and the data generation process (from static to on-policy).
  • Hybrid / multi-stage training can preserve fluency while encouraging the model to generate high-value responses in certain tasks.
  • Future improvements might focus on more efficient baselines, better reference policy management, and smoother transitions between MLE and RL phases.

Takeaway: By understanding both the rigorous foundation of maximum likelihood Transformer training and how GRPO modifies it into a reward-driven RL scheme, we can better tailor language models to produce desirable outputs beyond mere likelihood matching. The synergy of large-scale pre-training and reward-guided fine-tuning is likely to remain a core strategy for building advanced, aligned language models.


Further Reading & References

  • “Attention Is All You Need” by Vaswani et al., 2017, for the original Transformer.
  • “Proximal Policy Optimization” by Schulman et al., 2017, for PPO.
  • “Learning to Summarize from Human Feedback” by Stiennon et al., 2020, for a practical RLHF application.
  • “DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models” by Shao et al., 2024, which introduces Group Relative Policy Optimization (GRPO), simplifying PPO’s baseline with group-level statistics.