I read the paper and code so you don't have to: Self-Adaptive Large Language Models

January 18, 2025

The work described in this blog post is based on the research and code from SakanaAI’s Transformer²: Self-adaptive LLMs project. The paper presents innovative methods for dynamically adapting large language models (LLMs) to new tasks, and the accompanying codebase implements these concepts in practice.

Paper: Transformer²: Self-Adaptive LLMs

Code: SakanaAI’s Transformer² GitHub Repository

Blog: SakanaAI’s Transformer² Blog Post


1. The Mathematical Backbone: SVF, RL, and Weighted Combinations

Singular Value Fine-tuning (SVF)

SVF leverages Singular Value Decomposition (SVD) to fine-tune LLMs efficiently:

  • SVD of a Weight Matrix:

    W = U \cdot \Sigma \cdot V^T

    where \Sigma = \text{diag}(\sigma_1, \dots, \sigma_r) contains the singular values \sigma_i.

  • Fine-tuning with a Mask: Adjust the singular values with a learnable vector \mathbf{z} (or mask \mathbf{m}) to obtain new weights:

    W' = U \cdot \text{diag}(\sigma_i \times m_i) \cdot V^T

    The mask \mathbf{m} scales each singular value, modulating the contribution of each singular component to W'.
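To make the idea concrete, here is a minimal PyTorch sketch of the SVF recipe described above: decompose a weight matrix once, then learn only a per-singular-value scaling vector. The variable names (W, z, W_prime) and shapes are illustrative, not taken from the repository.

    import torch

    # Sketch of SVF: decompose once, then fine-tune only a scaling vector z.
    W = torch.randn(256, 256)                   # a frozen weight matrix
    U, S, Vh = torch.linalg.svd(W, full_matrices=False)

    z = torch.nn.Parameter(torch.ones_like(S))  # learnable mask m, one scalar per singular value

    # W' = U · diag(σ_i × m_i) · V^T — only z receives gradients during training.
    W_prime = U @ torch.diag(S * z) @ Vh

The only trainable quantity is a vector whose length equals the number of singular values, which is what makes SVF so parameter-efficient.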

Python-to-Math Variable Correspondence:

  • W in the code corresponds to a weight matrix of the model.
  • U, S, and V (extracted as decomposed_params[f"{param_name}.U"], ...S, ...V in the code) represent the components from the SVD of W.
  • mm (from policy.get_mask(...)) is the mask \mathbf{m}.
  • The resulting recomposed matrix W' is stored in variables like new_params[k].

Code Implementation:

  • compose_new_params Function (utils.py):

    mm = policy.get_mask(learnable_params[param_name])
    return (
        decomposed_params[f"{param_name}.U"]
        @ torch.diag_embed(decomposed_params[f"{param_name}.S"] * mm)
        @ decomposed_params[f"{param_name}.V"].T
    ) * normalization_factor

    • Here, mm corresponds to \mathbf{m}.
    • decomposed_params[f"{param_name}.S"] * mm performs the element-wise multiplication \sigma_i \times m_i.
    • The expression reconstructs W' = U \cdot \text{diag}(\sigma_i \times m_i) \cdot V^T.
  • forward Function (utils.py):

    for k in base_params:
        if "mlp" in k:
            new_params[k] = compose_new_params(policy, k, decomposed_params, learnable_params)
            model.get_parameter(k).copy_(new_params[k])
        else:
            new_params[k] = base_params[k]

    • For each MLP-related weight k, new_params[k] holds the updated weight W' for that layer.
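For context, the decomposed_params dictionary used above has to be produced by a one-time SVD pass over the base weights before any fine-tuning. The helper below is a hypothetical sketch of that step (the function name and exact key layout are assumptions, chosen only to match what compose_new_params reads):

    import torch

    def decompose_mlp_params(base_params: dict) -> dict:
        """Hypothetical one-time SVD pass producing the "{name}.U/.S/.V" entries
        read by compose_new_params. Not the repository's exact code."""
        decomposed = {}
        for name, W in base_params.items():
            if "mlp" not in name:
                continue                        # SVF only touches MLP weights
            U, S, Vh = torch.linalg.svd(W.float(), full_matrices=False)
            decomposed[f"{name}.U"] = U
            decomposed[f"{name}.S"] = S
            decomposed[f"{name}.V"] = Vh.T      # stored so that .V.T recovers V^T at recomposition time
        return decomposed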

Understanding the Mask

A mask \mathbf{m} is a vector of scaling factors applied to the singular values:

\sigma_i' = \sigma_i \times m_i

In code, the mask mm modulates the singular values as shown:

mm = policy.get_mask(learnable_params[param_name])

Here, mm[i] corresponds to m_i, scaling the i-th singular value \sigma_i.
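A tiny numeric example (values chosen arbitrarily) makes the scaling concrete: a mask entry of 0 removes a singular component entirely, while entries above 1 amplify one.

    import torch

    sigma = torch.tensor([4.0, 2.0, 0.5])   # singular values σ_i
    m = torch.tensor([1.2, 0.0, 1.0])       # mask values m_i
    print(sigma * m)                        # tensor([4.8000, 0.0000, 0.5000]), i.e. σ'_i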

Multi-Layer Perceptron (MLP)

An MLP (Multi-Layer Perceptron) is a feed-forward neural network:

  • It consists of layers of interconnected neurons.
  • In transformer models, the MLP refers to the feed-forward portion in each layer following the self-attention mechanism.

When the code checks:

if "mlp" in k:
...

it targets parameters belonging to MLP layers, indicating that the SVF adjustments are applied specifically to these dense layers.
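As an illustration, assuming a Hugging Face Llama-style checkpoint (the model name and parameter naming scheme here are assumptions, not fixed by the code above), the filter would pick out keys like these:

    from transformers import AutoModelForCausalLM

    # Hypothetical example: list the parameters an "mlp" filter would select.
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

    for name, param in model.named_parameters():
        if "mlp" in name:                      # same check as the snippet above
            print(name, tuple(param.shape))    # e.g. "model.layers.0.mlp.gate_proj.weight"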


2. Reinforcement Learning Policy Gradient

Mathematical Concept:

Policy gradient methods update parameters \theta to maximize the expected reward J(\theta):

\nabla_{\theta} J(\theta) = \mathbb{E}\left[ r \cdot \nabla_{\theta} \log p(y|x; \theta) \right]

which is estimated by minimizing a surrogate loss, with optional KL-divergence regularization:

\text{loss} = -\log p(y|x; \theta) \cdot r + \lambda D_{\text{KL}}(p_{\theta} \,\|\, p_{\text{ref}})
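For completeness, the gradient expression above comes from the standard log-derivative (score-function) trick, which turns the gradient of an expectation into an expectation of a gradient:

    \nabla_\theta \, \mathbb{E}_{y \sim p_\theta(\cdot|x)}[r(y)]
      = \sum_y r(y) \, \nabla_\theta \, p_\theta(y|x)
      = \sum_y r(y) \, p_\theta(y|x) \, \nabla_\theta \log p_\theta(y|x)
      = \mathbb{E}_{y \sim p_\theta(\cdot|x)}\big[ r(y) \, \nabla_\theta \log p_\theta(y|x) \big]

In practice the expectation is replaced by sampled outputs, which is exactly what minimizing -\log p(y|x; \theta) \cdot r over a batch of generations does.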

Python-to-Math Variable Correspondence:

  • x and y: Input prompt and generated output text.
  • \theta: Policy parameters updated during RL, corresponding to weights in objects like policy.
  • r: Reward signal, computed from the correctness of the model's output.
  • \alpha: In weighted combinations, the adaptive coefficients stored as adaptive_weights in the policy.

Code Implementation in Reinforce class:

  • Policy Gradient Calculation:

    log_likelihood = selected_log_probs.sum(axis=-1)
    pg = -log_likelihood * rewards[j]
    loss = pg
    if use_kl_loss:
        kl_div = F.kl_div(...)
        loss = loss + kl_ref_coeff * kl_div
    scaled_loss = loss / clipped_batch_size
    scaled_loss.backward()

    • log_likelihood computes \log p(y|x; \theta).
    • pg = -log_likelihood * rewards[j] corresponds to -\log p(y|x; \theta) \cdot r.
    • The KL-divergence term mirrors the regularizer \lambda D_{\text{KL}}(p_{\theta} \,\|\, p_{\text{ref}}).
  • Parameter Update:

    def update(self, policy):
        torch.nn.utils.clip_grad_norm_(policy.trainable_params, max_grad_norm)
        self.optimizer.step()
        self.optimizer.zero_grad()

    • This updates the RL policy parameters \theta using the gradients from the computed loss.
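To see how the pieces of the loss map onto the math in isolation, here is a self-contained toy REINFORCE step (not the repository's Reinforce class); the vocabulary size, sampled tokens, and reward value are made up:

    import torch
    import torch.nn.functional as F

    # Toy setup: logits over a 5-token vocabulary for a 3-token sampled output y.
    logits = torch.randn(3, 5, requires_grad=True)
    sampled_ids = torch.tensor([1, 4, 2])   # the "generated" tokens
    reward = 1.0                            # r: e.g. 1 if the answer was judged correct

    log_probs = F.log_softmax(logits, dim=-1)
    selected_log_probs = log_probs[torch.arange(3), sampled_ids]
    log_likelihood = selected_log_probs.sum()   # log p(y|x; θ)

    loss = -log_likelihood * reward             # -log p(y|x; θ) · r
    loss.backward()                             # gradients flow into θ (here, the raw logits)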

3. Weighted Combination of Expert Models

Mathematical Concept:

Weighted combination uses coefficients \alpha to blend expert weights:

W_{\text{combined}} = \sum_{i=1}^{N} \alpha_i \, W^{(i)}

where:

  • W^{(i)} are the weight matrices from expert model i.
  • \alpha_i are the combination coefficients, analogous to variables in code like adaptive_weights.
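A toy example with two "experts" and arbitrary numbers shows the blend (illustrative only, not from the repository):

    import torch

    W1 = torch.tensor([[1.0, 0.0], [0.0, 1.0]])   # expert 1 weights W^(1)
    W2 = torch.tensor([[0.0, 2.0], [2.0, 0.0]])   # expert 2 weights W^(2)
    alpha = torch.tensor([0.7, 0.3])              # combination coefficients α

    W_combined = alpha[0] * W1 + alpha[1] * W2
    print(W_combined)   # tensor([[0.7000, 0.6000], [0.6000, 0.7000]])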

Python-to-Math Variable Correspondence:

  • adaptive_weights: Corresponds to the coefficients \alpha.
  • vs: A list containing the expert weights W^{(i)} for a specific parameter k.
  • output_params[k]: Represents the combined weight W_{\text{combined}} for parameter k.

Code Implementation in WeightedCombination class:

  • Combining Weights:

    def get_learnable_params(self):
        adaptive_coeff_per_layer = self.get_coeff_per_layer()
        output_params = {}
        for i, (k, vs) in enumerate(self.original_params.items()):
            cs_coeff = adaptive_coeff_per_layer[:, i]
            out = vs[0] * cs_coeff[0]
            for j, other_v in enumerate(vs[1:]):
                out = out + other_v * cs_coeff[j + 1]
            output_params[k] = out
        return output_params

    • For each parameter key k, this computes W_k = \sum_{i=1}^{N} \alpha_{i,k} \, v^{(i)}_k, where cs_coeff[j] corresponds to \alpha_{j,k} and each vs[j] corresponds to W^{(j)}_k.
    • The result output_params[k] is the weighted combination W_{\text{combined}} for that parameter.

4. Summary of Variable Correspondences

  • W, W' (e.g., base_params[k], new_params[k]): Weight matrices before and after fine-tuning.
  • U, S, V (e.g., decomposed_params[...]): Matrices from the SVD used in SVF.
  • \sigma_i (e.g., elements of decomposed_params[f"{param_name}.S"]): Singular values of the weight matrices.
  • \mathbf{m} or mask (e.g., mm): Scaling factors applied to the singular values during fine-tuning.
  • \theta (implied in policy parameters): Parameters of the policy optimized via reinforcement learning.
  • r (e.g., rewards[j]): Reward signal for a given output, derived from task correctness.
  • \alpha (e.g., adaptive_weights, adaptive_coeff_per_layer): Coefficients for the weighted combination of expert models.
  • x, y (e.g., prompt, result.generation): Input prompts and generated outputs during evaluation.

Conclusion

This framework ties the mathematics directly to the code. By decomposing weight matrices via SVD and adjusting their singular values with a learned mask \mathbf{m}, the system fine-tunes the MLP components efficiently. Reinforcement learning optimizes the policy parameters \theta using policy gradients driven by rewards r. Finally, a weighted-combination policy uses coefficients \alpha to blend expert models, dynamically adapting the LLM to specific tasks.

Understanding how Python variables like mm, adaptive_weights, and new_params correspond to mathematical symbols such as \mathbf{m}, \alpha, and W' helps demystify these adaptive learning strategies and shows how self-adaptive LLMs can be engineered to handle diverse tasks.