I read the paper and code so you don't have to: Self-Adaptive Large Language Models
The work described in this blog post is based on the research and code from SakanaAI’s Transformer²: Self-adaptive LLMs project. The paper presents innovative methods for dynamically adapting large language models (LLMs) to new tasks, and the accompanying codebase implements these concepts in practice.
Paper: Transformer²: Self-Adaptive LLMs
Code: SakanaAI’s Transformer² GitHub Repository
Blog: SakanaAI’s Transformer² Blog Post
1. The Mathematical Backbone: SVF, RL, and Weighted Combinations
Singular Value Fine-tuning (SVF)
SVF leverages Singular Value Decomposition (SVD) to fine-tune LLMs efficiently:
- SVD of a Weight Matrix:
  $$W = U \Sigma V^\top$$
  where $\Sigma = \mathrm{diag}(\sigma_1, \ldots, \sigma_r)$ contains the singular values $\sigma_i$.
- Fine-tuning with a Mask: Adjust the singular values with a learnable vector (or mask) $m$ to obtain new weights:
  $$W' = U \,\mathrm{diag}(\sigma_1 m_1, \ldots, \sigma_r m_r)\, V^\top$$
  The mask $m$ scales each singular value $\sigma_i$, modulating the contribution of each singular component in $W'$ (see the sketch just below).
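Before mapping this onto the repository's variables, here is a minimal, self-contained PyTorch sketch of the decomposition and masked recomposition described above; the tensor names follow the math rather than the project's code, and the shapes are purely illustrative:

```python
import torch

# Toy weight matrix standing in for one model parameter W.
W = torch.randn(64, 32)

# SVD: W = U @ diag(S) @ Vh, where S holds the singular values sigma_i.
U, S, Vh = torch.linalg.svd(W, full_matrices=False)

# Learnable per-singular-value mask m, initialized to 1 (identity rescaling).
m = torch.ones_like(S, requires_grad=True)

# Recompose W' = U @ diag(S * m) @ Vh; gradients flow only into the small vector m.
W_prime = U @ torch.diag(S * m) @ Vh

# With m = 1 the recomposition reproduces W up to numerical error.
assert torch.allclose(W_prime, W, atol=1e-4)
```

Because only $m$ (one scalar per singular value) is trainable, the number of learnable parameters per matrix stays tiny compared to full fine-tuning, which is the efficiency argument behind SVF.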
Python-to-Math Variable Correspondence:
- `W` in the code corresponds to a weight matrix $W$ of the model.
- `U`, `S`, and `V` (extracted as `decomposed_params[f"{param_name}.U"]`, `...S`, `...V` in the code) represent the components $U$, $\Sigma$, $V$ from the SVD of $W$.
- `mm` (from `policy.get_mask(...)`) is the mask $m$.
- The resulting recomposed matrix $W'$ is stored in variables like `new_params[k]`.
Code Implementation:
`compose_new_params` Function (`utils.py`):

```python
mm = policy.get_mask(learnable_params[param_name])
return (
    decomposed_params[f"{param_name}.U"]
    @ torch.diag_embed(decomposed_params[f"{param_name}.S"] * mm)
    @ decomposed_params[f"{param_name}.V"].T
) * normalization_factor
```

- Here, `mm` corresponds to $m$.
- `decomposed_params[f"{param_name}.S"] * mm` performs the element-wise multiplication $\sigma_i \cdot m_i$.
- The expression reconstructs $W' = U \,\mathrm{diag}(\sigma \odot m)\, V^\top$ (up to the normalization factor).
`forward` Function (`utils.py`):

```python
for k in base_params:
    if "mlp" in k:
        new_params[k] = compose_new_params(policy, k, decomposed_params, learnable_params)
        model.get_parameter(k).copy_(new_params[k])
    else:
        new_params[k] = base_params[k]
```

- For each MLP-related weight $W$, `new_params[k]` represents the updated weight $W'$ for that layer.
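The loop above assumes `decomposed_params` and `base_params` were prepared ahead of time. The repository ships its own decomposition step; purely as an illustration of the idea (the function name and key format below are assumptions mirroring the strings used above), precomputing the SVD factors for the MLP weights could look like:

```python
import torch

def decompose_mlp_params(model):
    """Precompute U, S, V for every 2-D MLP weight so that training only learns the mask."""
    decomposed_params = {}
    for name, param in model.named_parameters():
        if "mlp" in name and param.ndim == 2:
            U, S, Vh = torch.linalg.svd(param.detach().float(), full_matrices=False)
            decomposed_params[f"{name}.U"] = U
            decomposed_params[f"{name}.S"] = S
            # Store V (not Vh) so recomposition can use `... @ V.T`, as in compose_new_params.
            decomposed_params[f"{name}.V"] = Vh.T
    return decomposed_params
```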
Understanding the Mask
A mask $m = (m_1, \ldots, m_r)$ is a vector of scaling factors applied to the singular values:
$$\sigma_i \mapsto \sigma_i \, m_i$$
In code, the mask `mm` modulates the singular values as shown:

```python
mm = policy.get_mask(learnable_params[param_name])
```

Here, `mm[i]` corresponds to scaling the $i$-th singular value $\sigma_i$.
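How `policy.get_mask(...)` turns the raw learnable parameters into the mask is an implementation detail of the repository; as a hedged illustration only (not a claim about the actual code), one simple parameterization keeps the mask positive and equal to 1 at initialization:

```python
import torch

def get_mask(raw_params: torch.Tensor) -> torch.Tensor:
    # 2 * sigmoid keeps every entry in (0, 2); raw_params == 0 yields a mask of exactly 1,
    # so training starts from the unmodified singular values.
    return 2.0 * torch.sigmoid(raw_params)

raw = torch.zeros(16, requires_grad=True)  # one learnable scalar per singular value
mm = get_mask(raw)                         # all ones at initialization
```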
Multi-Layer Perceptron (MLP)
An MLP (Multi-Layer Perceptron) is a feed-forward neural network:
- It consists of layers of interconnected neurons.
- In transformer models, the MLP refers to the feed-forward portion in each layer following the self-attention mechanism.
When the code checks:
if "mlp" in k: ...
it targets parameters belonging to MLP layers, indicating that the SVF adjustments are applied specifically to these dense layers.
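To make the substring check concrete: in Hugging Face-style decoder checkpoints (used here only as an illustrative assumption about naming, not something the repository requires), parameter names typically look like `model.layers.0.mlp.gate_proj.weight`, so filtering on `"mlp"` picks out exactly the feed-forward projections:

```python
# Hypothetical parameter names following common Hugging Face conventions.
param_names = [
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.mlp.gate_proj.weight",
    "model.layers.0.mlp.up_proj.weight",
    "model.layers.0.mlp.down_proj.weight",
]

mlp_keys = [k for k in param_names if "mlp" in k]
# -> only the three feed-forward projections; attention weights are left untouched by SVF here.
```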
2. Reinforcement Learning Policy Gradient
Mathematical Concept:
Policy gradient methods update the policy parameters $\theta$ to maximize the expected reward:
$$J(\theta) = \mathbb{E}_{y \sim \pi_\theta(\cdot \mid x)}\big[\, r(x, y) \,\big]$$
with optional KL-divergence regularization against a frozen reference model $\pi_{\text{ref}}$:
$$J_{\text{KL}}(\theta) = \mathbb{E}_{y \sim \pi_\theta(\cdot \mid x)}\big[\, r(x, y) \,\big] \;-\; \lambda \, D_{\text{KL}}\!\big(\pi_\theta(\cdot \mid x) \,\|\, \pi_{\text{ref}}(\cdot \mid x)\big)$$
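The minus sign in the loss used later (`pg = -log_likelihood * rewards[j]`) follows from the standard log-derivative (REINFORCE) identity, which is worth writing out once:

$$
\nabla_\theta \, \mathbb{E}_{y \sim \pi_\theta(\cdot \mid x)}\big[ r(x, y) \big]
= \mathbb{E}_{y \sim \pi_\theta(\cdot \mid x)}\big[ r(x, y) \, \nabla_\theta \log \pi_\theta(y \mid x) \big]
\approx \frac{1}{N} \sum_{j=1}^{N} r(x, y_j) \, \nabla_\theta \log \pi_\theta(y_j \mid x)
$$

Minimizing the per-sample loss $-\log \pi_\theta(y_j \mid x)\, r(x, y_j)$ with automatic differentiation therefore yields exactly this gradient estimate: ascent on $J(\theta)$ becomes descent on the negated loss.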
Python-to-Math Variable Correspondence:
- $x$ and $y$: Input prompt and generated output text.
- $\theta$: Policy parameters updated during RL, corresponding to weights in objects like `policy`.
- $r$: Reward signal, computed from the correctness of the model's output.
- $\alpha_k$: In weighted combinations, these are the adaptive coefficients stored as `adaptive_weights` in the policy.
Code Implementation in the `Reinforce` class:
- Policy Gradient Calculation:

```python
log_likelihood = selected_log_probs.sum(axis=-1)
pg = -log_likelihood * rewards[j]
loss = pg
if use_kl_loss:
    kl_div = F.kl_div(...)
    loss = loss + kl_ref_coeff * kl_div
scaled_loss = loss / clipped_batch_size
scaled_loss.backward()
```

- `log_likelihood` computes $\log \pi_\theta(y \mid x)$.
- `pg = -log_likelihood * rewards[j]` corresponds to $-\log \pi_\theta(y \mid x)\, r(x, y)$.
- The KL-divergence computation and addition mirror the regularized term $\lambda \, D_{\text{KL}}(\pi_\theta \,\|\, \pi_{\text{ref}})$, with `kl_ref_coeff` playing the role of $\lambda$.
- Parameter Update:

```python
def update(self, policy):
    torch.nn.utils.clip_grad_norm_(policy.trainable_params, max_grad_norm)
    self.optimizer.step()
    self.optimizer.zero_grad()
```

- This updates the RL policy parameters $\theta$ using gradients derived from the computed loss.
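Putting the loss and the update together, here is a toy, self-contained REINFORCE step over a categorical distribution standing in for the LLM's output distribution; the names, shapes, KL direction, and reward rule are all illustrative assumptions, not the repository's API:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy "policy": logits over 4 actions, playing the role of the LLM's token distribution.
theta = torch.zeros(4, requires_grad=True)
optimizer = torch.optim.Adam([theta], lr=0.1)

# Frozen reference distribution for the optional KL regularizer.
ref_log_probs = torch.log_softmax(torch.zeros(4), dim=-1)
kl_ref_coeff = 0.1

for step in range(100):
    log_probs = torch.log_softmax(theta, dim=-1)
    action = torch.multinomial(log_probs.exp().detach(), num_samples=1).item()
    reward = 1.0 if action == 2 else 0.0      # pretend action 2 is the "correct" answer

    pg = -log_probs[action] * reward          # REINFORCE: -log pi(y|x) * r
    # KL(pi_theta || pi_ref): F.kl_div(input, target) computes KL(target || input),
    # so the reference goes in the "input" slot and the current policy in the "target" slot.
    kl_div = F.kl_div(ref_log_probs, log_probs, reduction="sum", log_target=True)
    loss = pg + kl_ref_coeff * kl_div

    loss.backward()
    torch.nn.utils.clip_grad_norm_([theta], max_norm=1.0)
    optimizer.step()
    optimizer.zero_grad()
```

After a few dozen steps the probability of the rewarded action rises while the KL term keeps the distribution from collapsing onto it entirely; the real training loop does the same thing with generated text, a task-specific reward, and the SVF mask parameters as $\theta$.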
3. Weighted Combination of Expert Models
Mathematical Concept:
Weighted combination uses coefficients $\alpha_k$ to blend expert weights:
$$W' = \sum_{k} \alpha_k \, W_k$$
where:
- $W_k$ are weight matrices from expert model $k$.
- $\alpha_k$ are combination coefficients, analogous to variables in code like `adaptive_weights`.
Python-to-Math Variable Correspondence:
- `adaptive_weights`: Corresponds to the coefficients $\alpha_k$.
- `vs`: A list containing the expert weights $W_k$ for a specific parameter key.
- `output_params[k]`: Represents the combined weight $W'$ for that parameter.
Code Implementation in the `WeightedCombination` class:
- Combining Weights:

```python
def get_learnable_params(self):
    adaptive_coeff_per_layer = self.get_coeff_per_layer()
    output_params = {}
    for i, (k, vs) in enumerate(self.original_params.items()):
        cs_coeff = adaptive_coeff_per_layer[:, i]
        out = vs[0] * cs_coeff[0]
        for j, other_v in enumerate(vs[1:]):
            out = out + other_v * cs_coeff[j + 1]
        output_params[k] = out
    return output_params
```
- For each parameter key `k`, this computes $W' = \sum_{j} \alpha_j \, W^{(j)}$, where `cs_coeff[j]` corresponds to $\alpha_j$ and each `vs[j]` corresponds to the expert weight $W^{(j)}$.
- The result `output_params[k]` is the weighted combination for that parameter.
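Stripped of the policy machinery, the blending itself is a few lines. The sketch below uses two hypothetical expert state dicts and a softmax over learnable logits to keep the blend a convex combination; both the checkpoints and the softmax normalization are illustrative assumptions, not necessarily how the repository parameterizes `adaptive_weights`:

```python
import torch

# Two hypothetical "expert" checkpoints sharing the same parameter keys.
expert_a = {"mlp.weight": torch.randn(8, 8), "mlp.bias": torch.randn(8)}
expert_b = {"mlp.weight": torch.randn(8, 8), "mlp.bias": torch.randn(8)}

# Learnable logits over experts; softmax yields the coefficients alpha_k.
logits = torch.zeros(2, requires_grad=True)
alpha = torch.softmax(logits, dim=-1)

# W' = sum_k alpha_k * W_k, computed per parameter key.
combined = {k: alpha[0] * expert_a[k] + alpha[1] * expert_b[k] for k in expert_a}
```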
4. Summary of Variable Correspondences
- $W$, $W'$ (e.g., `base_params[k]`, `new_params[k]`): Weight matrices before and after fine-tuning.
- $U$, $\Sigma$, $V$ (e.g., `decomposed_params[...]`): Matrices from the SVD decomposition used in SVF.
- $\sigma_i$ (e.g., elements of `decomposed_params[f"{param_name}.S"]`): Singular values of the weight matrices.
- $m$, the mask (e.g., `mm`): Scaling factors applied to the singular values during fine-tuning.
- $\theta$ (implied in the policy parameters): Parameters of the policy being optimized via reinforcement learning.
- $r$ (e.g., `rewards[j]`): Reward signal for a given output, derived from task correctness.
- $\alpha_k$ (e.g., `adaptive_weights`, `adaptive_coeff_per_layer`): Coefficients for the weighted combination of expert models.
- $x$, $y$ (e.g., `prompt`, `result.generation`): Input prompts and generated outputs during evaluation.
Conclusion
This framework ties the mathematics directly to the code. By decomposing weight matrices via SVD and adjusting singular values with a learned mask $m$, the system fine-tunes the MLP components efficiently. Reinforcement learning optimizes the policy parameters $\theta$ using policy gradients weighted by the rewards $r$. Furthermore, a weighted-combination policy uses coefficients $\alpha_k$ to blend expert models, dynamically adapting the LLM to specific tasks.
Understanding how Python variables like `mm`, `adaptive_weights`, and `new_params` correspond to mathematical symbols such as $m$, $\alpha_k$, and $W'$ helps demystify these adaptive-learning strategies, providing insight into how self-adaptive LLMs can be engineered to excel across diverse tasks.