|
--- |
|
license: apache-2.0 |
|
language: en |
|
tags: |
|
- llama |
|
- instruction-residual |
|
- parameter-efficient |
|
- safetensors |
|
- transformers |
|
base_model: |
|
- meta-llama/Llama-3.1-8B |
|
- meta-llama/Llama-3.1-8B-Instruct |
|
--- |
|
|
|
# Llama-3.1-8b-Instruct-Residual |
|
|
|
**Full-rank instruction residual for Llama-3.1-8B** |
|
|
|
This repository provides the **full-rank instruction residual** \(Δθ = θ_{instruct} - θ_{base}\) between the instruction-tuned Llama-3.1-8B-Instruct model and its corresponding base Llama-3.1-8B model. By adding this residual to a fresh base checkpoint, you can restore instruction-following capabilities **without** running a full fine-tuning cycle. |
|
|
|
## How it was created |
|
|
|
We follow the *instruction residual* approach introduced by Jindal et al. (2024): |
|
|
|
> “In this section, we describe the instruction residual approach to simply regain the instruction following capabilities. We compute the instruction residual between an instruction following LLM \(θ_{i,d_1,v_1}\) and its corresponding base model \(θ_{b,d_1}\) in the parametric space as |
|
> \[ |
|
> Θ_{r,v_1} = θ_{i,d_1,v_1} - θ_{b,d_1}. |
|
> \] |
|
> This tensor subtraction extracts the instruction-specific information, which can then be added to any base model.” |
|
|
|
The full paper is available at: https://arxiv.org/abs/2410.10739 |
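
For reference, here is a minimal sketch of how such a residual can be computed, assuming both checkpoints expose identical parameter names (the FP16 cast and output filename are illustrative, not the exact export script used):

```python
from transformers import AutoModelForCausalLM
from safetensors.torch import save_file
import torch

# Load both checkpoints on CPU in FP16 (roughly 32 GB of RAM in total)
base = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B", torch_dtype=torch.float16
)
instruct = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct", torch_dtype=torch.float16
)

base_sd = base.state_dict()
instruct_sd = instruct.state_dict()

# Θ_r = θ_instruct - θ_base, computed tensor by tensor
residual = {
    name: (instruct_sd[name] - base_sd[name]).contiguous()
    for name in base_sd
}

save_file(residual, "pytorch_model.safetensors")
```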
|
|
|
## Files |
|
|
|
- `pytorch_model.safetensors` — full-rank FP16 residual weights (~16 GB). |
|
- `config.json` — configuration matching the Llama-3.1-8B architecture. |
|
- `README.md` — this model card. |
|
|
|
## Usage |
|
|
|
Below is a minimal example showing how to apply the residual to a base model: |
|
|
|
```python |
|
from transformers import AutoModelForCausalLM |
|
from safetensors.torch import load_file |
|
import torch |
|
|
|
# 1) Load base |
|
model = AutoModelForCausalLM.from_pretrained( |
|
"meta-llama/Llama-3.1-8B", |
|
torch_dtype=torch.float16, |
|
device_map="auto", |
|
) |
|
|
|
# 2) Load residual |
|
residual_sd = load_file("pytorch_model.safetensors", device="cpu") |
|
|
|
# 3) Apply residual in place
params = dict(model.named_parameters())  # build the name -> parameter map once
for name, delta in residual_sd.items():
    param = params[name]
    param.data += delta.to(param.device, param.dtype)
|
|
|
# 4) Save or push |
|
model.save_pretrained("llama-3.1-8b-base-plus-instruct") |
|
``` |
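
After merging, a quick generation pass is a useful sanity check that the model now follows instructions. This sketch assumes the tokenizer (and its chat template) are loaded from the Instruct repository and that `model` is the merged model from the snippet above; the prompt is purely illustrative:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

messages = [{"role": "user", "content": "Give three tips for writing clear commit messages."}]
inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

outputs = model.generate(inputs, max_new_tokens=256, do_sample=False)
print(tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True))
```

If you push the merged model, save the tokenizer alongside it (e.g. `tokenizer.save_pretrained("llama-3.1-8b-base-plus-instruct")`) so the chat template travels with the weights.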
|
|
|
For full scripts, see the `examples/` folder. |
|
|
|
## Intended Use & Limitations |
|
|
|
- **Intended Use**: Add instruction-following capabilities to Llama-3.1-8B base models. |
|
- **Limitations**: |
|
  - The residual must be applied to the exact base checkpoint it was computed against (a key/shape check is sketched after this list).

  - The residual is stored in FP16 (~16 GB); if the base model is loaded in 4-bit, dequantize it to full precision before applying the residual.

  - Applying the residual to a mismatched architecture produces invalid weights.
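
The first limitation can be checked cheaply before modifying any weights. A minimal sketch, reusing `model` and `residual_sd` from the usage snippet above and assuming the residual keys follow the `model.named_parameters()` naming:

```python
# Verify every residual tensor has a matching parameter with the same shape
params = dict(model.named_parameters())

missing = [k for k in residual_sd if k not in params]
mismatched = [
    k for k in residual_sd
    if k in params and params[k].shape != residual_sd[k].shape
]

if missing or mismatched:
    raise ValueError(
        f"Residual does not match this checkpoint: "
        f"{len(missing)} missing keys, {len(mismatched)} shape mismatches"
    )
```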
|
|
|
## License |
|
|
|
This residual is released under the **Apache License 2.0**. See the `LICENSE` file for details. |
|
|
|
## References |
|
As noted above, this method was introduced by **Jindal et al., 2024** (arXiv:2410.10739):
|
|
|
```bibtex |
|
@misc{jindal2024balancingcontinuouspretraininginstruction, |
|
title={Balancing Continuous Pre-Training and Instruction Fine-Tuning: Optimizing Instruction-Following in LLMs}, |
|
author={Ishan Jindal and Chandana Badrinath and Pranjal Bharti and Lakkidi Vinay and Sachin Dev Sharma}, |
|
year={2024}, |
|
eprint={2410.10739}, |
|
archivePrefix={arXiv}, |
|
primaryClass={cs.CL}, |
|
url={https://arxiv.org/abs/2410.10739}, |
|
} |
|
``` |
|
|
|
|