Llama-2 finetuning using PEFT and Multi-GPU setup
ChatGPT and Gemini are great models for general prompting. However, due to model alignment, they refuse to respond to prompts related to research that involves toxic language. Llama-2 provides an open-source alternative for training an unaligned model, so for a recent research project we needed to fine-tune a Llama-2 model. For the larger models, I also needed a multi-GPU setup to fit the model in memory for training, especially with large context sizes. However, I soon realized that open-source implementations of LoRA training for this setup are limited. Multi-GPU inference, on the other hand, is as simple as using "auto" for the device mapping in the Hugging Face implementation.
“There’s two strategies that have been shown to work: Gpipe-style model parallelism, and tensor parallelism. HF Accelerate and Deepspeed both support the former. However sadly they don’t properly support LoRA at present.” -Jeremy Howard
In addition to the limitations of Accelerate and DeepSpeed, I also tried Anyscale’s Ray/Alpa, to further disappointment. Perhaps things have changed since I started working on this. However, with my solution, I never needed to look back.
I have always found JAX convenient to work with and much faster than PyTorch, so I implemented LoRA and GPU tensor parallelism on top of ayaka14732’s excellent implementation of Llama-2 in JAX. In addition, I made the following changes to the codebase:
- Alpaca format instruct dataset loading
- Parameter configuration for 13B models
- Hand-tuned sharding for projection and attention parameters for optimized distribution of parameters across GPUs
- Merging of parameters back into Hugging Face format after training (for use with other Hugging Face-compatible libraries)
- Python 3.9 support
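The LoRA, sharding, and merging items in the list above can be sketched in a few lines of JAX. This is a minimal illustration under assumptions, not the actual code from the codebase: the helper names (init_lora, lora_forward, merge_lora), the mesh axis name "mp", and the toy sizes are all hypothetical, and the column-wise split shown here is only one plausible layout for a projection matrix.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def init_lora(key, d_in, d_out, rank):
    """A is small and random, B is zero, so the adapter starts as a no-op."""
    a = jax.random.normal(key, (d_in, rank)) * 0.01
    b = jnp.zeros((rank, d_out))
    return a, b


def lora_forward(x, w, a, b, scale):
    """Frozen projection plus the low-rank update: x @ w + scale * (x @ a) @ b."""
    return x @ w + scale * ((x @ a) @ b)


def merge_lora(w, a, b, scale):
    """Fold the adapter into the base weight, giving a plain dense matrix."""
    return w + scale * (a @ b)


# 1-D device mesh over every visible device (a single CPU also works).
devices = jax.devices()
mesh = Mesh(np.array(devices), axis_names=("mp",))
col_sharded = NamedSharding(mesh, P(None, "mp"))  # split the output dimension
replicated = NamedSharding(mesh, P())             # full copy on each device

# Toy sizes (assumed); d_out is chosen to divide evenly across devices.
d_in, rank, alpha = 64, 8, 16
d_out = 16 * len(devices)
scale = alpha / rank

k_w, k_a, k_b, k_x = jax.random.split(jax.random.PRNGKey(0), 4)
w = jax.device_put(jax.random.normal(k_w, (d_in, d_out)), col_sharded)
a, b_init = init_lora(k_a, d_in, d_out, rank)
a = jax.device_put(a, replicated)
x = jax.random.normal(k_x, (4, d_in))

# Before training, B is zero and the adapter leaves the output unchanged.
y_init = lora_forward(x, w, a, b_init, scale)

# Stand-in for a trained B (random values, purely for illustration).
b = jax.device_put(jax.random.normal(k_b, (rank, d_out)) * 0.01, col_sharded)

y_adapter = jax.jit(lora_forward)(x, w, a, b, scale)
w_merged = merge_lora(w, a, b, scale)
y_merged = x @ w_merged
```

Sharding w and B along the output dimension while replicating A means each device computes its own slice of the output columns. After merging, w_merged has the same shape as the original weight, so it can be exported like any ordinary dense parameter.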
Update 03/29/2024: Several quality-of-life changes have been made to my codebase. I also published a preprint for my library. Evaluation of my approach shows over 12x performance improvement over Hugging Face PEFT/Microsoft DeepSpeed.
#machinelearning #generativeai