Building Tiny-Llama from the Ground Up: A Decoder-Only Model with Reinforcement Learning from Human Feedback (Human-in-the-Loop)

YOUNESS-ELBRAG
13 min read · Apr 20, 2024


Introduction:

In this article, we’ll build Tiny-Llama, a scaled-down version of the Llama architecture, from the ground up. We then train it with reinforcement learning from human feedback (RLHF), which makes it a human-in-the-loop system. Inspired by Karpathy’s makemore series, we aim to share practical insights from our experience implementing a decoder-only Transformer.

Content: We’ll loosely follow the structure of the original paper, focusing on practical implementation rather than publication-oriented formatting. We assume that setup steps, such as creating a virtual environment and installing dependencies, are already done.

Part 1: Building the Decoder-Only Tiny-Llama from the Ground Up

Key Enhancements: Our implementation of Tiny-Llama incorporates several key enhancements inspired by recent advancements in transformer models:

  1. RMS-Normalization: RMSNorm is a simplification of the original layer normalization (LayerNorm). LayerNorm is a normalization technique that mitigates internal covariate shift, which stabilizes layer activations and improves convergence. RMSNorm drops LayerNorm's mean-centering and keeps only the rescaling, so it is cheaper to compute. LLaMA and LLaMA 2 use it to normalize the input of each Transformer sub-layer.
  2. Activation Function: In its feed-forward layers, LLaMA 2 uses the SwiGLU activation function instead of ReLU, which improves model quality.
  3. Rotary Positional Embeddings (RoPE): Inspired by the GPT-NeoX project, LLaMA 2 replaces absolute positional embeddings with rotary positional embeddings applied at each layer. This improves the model’s handling of token positions.
  4. Increased Context Length and Grouped-Query Attention (GQA): LLaMA 2 doubles the context window from 2048 to 4096 tokens. Its larger variants (34B and 70B) also use grouped-query attention. These changes help with long documents, chat histories, and summarization, and they speed up inference.

1. KV-Caching for Efficient Inference

KV-caching is a key optimization in this implementation for speeding up language model (LM) decoding. In autoregressive decoding, each token’s prediction depends only on the preceding tokens, so the model’s self-attention is causal. A token’s representation therefore depends solely on itself and earlier tokens, never on future ones.

Within self-attention, each input token is projected into a query, a key, and a value. Because attention is causal, the keys and values of past tokens do not change as new tokens are generated. The KV-cache stores these key and value projections. At each decoding step, only the newest token has to be projected, and its key and value are appended to the cache. The cached entries are reused instead of recomputed, which markedly speeds up inference.

KV-caching does not change the model’s outputs: it trades memory for speed, and it is a key ingredient of efficient LLaMA decoding.
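To make the caching idea concrete, here is a minimal single-head sketch in PyTorch. It assumes one sequence with no batching, and random projection matrices stand in for trained weights. The names `causal_attention`, `KVCache`, and `decode_with_cache` are hypothetical and not part of the Tiny-Llama code. The sketch checks that decoding one token at a time with a cache reproduces full causal attention.

```python
import torch
import torch.nn.functional as F


def causal_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Reference: full causal self-attention over a (T, d) sequence, no cache."""
    T, d = q.shape
    scores = q @ k.T / d**0.5
    future = torch.ones(T, T).triu(diagonal=1).bool()  # True above the diagonal
    scores = scores.masked_fill(future, float("-inf"))  # hide future tokens
    return F.softmax(scores, dim=-1) @ v


class KVCache:
    """Hypothetical minimal cache that appends one token's key/value per step."""

    def __init__(self) -> None:
        self.k = None
        self.v = None

    def update(self, k_new: torch.Tensor, v_new: torch.Tensor):
        # k_new, v_new: (1, d) projections of the newest token only
        self.k = k_new if self.k is None else torch.cat([self.k, k_new], dim=0)
        self.v = v_new if self.v is None else torch.cat([self.v, v_new], dim=0)
        return self.k, self.v


def decode_with_cache(x, w_q, w_k, w_v) -> torch.Tensor:
    """Process tokens one at a time, projecting only the newest token each step."""
    cache = KVCache()
    outputs = []
    for t in range(x.shape[0]):
        x_t = x[t : t + 1]  # (1, d_model)
        q_t = x_t @ w_q  # query of the new token
        k_all, v_all = cache.update(x_t @ w_k, x_t @ w_v)
        # The cache holds only past tokens plus the current one, so no mask is needed.
        scores = q_t @ k_all.T / q_t.shape[-1] ** 0.5
        outputs.append(F.softmax(scores, dim=-1) @ v_all)
    return torch.cat(outputs, dim=0)


torch.manual_seed(0)
d_model, d_head, T = 16, 8, 5
x = torch.randn(T, d_model, dtype=torch.float64)
w_q, w_k, w_v = (torch.randn(d_model, d_head, dtype=torch.float64) for _ in range(3))
full = causal_attention(x @ w_q, x @ w_k, x @ w_v)
cached = decode_with_cache(x, w_q, w_k, w_v)
print(torch.allclose(full, cached))  # True
```

In a real model, the cache is kept per layer and typically has shape (batch, n_kv_heads, seq_len, d_head). Its size grows linearly with the generated length, which is why reducing the number of K/V heads matters.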

2. Grouped-Query Attention (GQA) for Enhanced Efficiency

The LLaMA 2 model uses grouped-query attention, a variant of Multi-Query Attention (MQA). Shazeer (2019) introduced MQA as a more efficient alternative to Multi-Head Attention (MHA) that incurs only a small loss in accuracy.

In traditional multi-head attention, the attention computation is replicated h times, where h is the number of attention heads, and each head has its own K and V projections. MQA removes this redundancy: it keeps all h query heads but shares a single K head and a single V head among them. GQA sits between the two. The h query heads are split into g groups (1 < g < h), and each group shares one K head and one V head.

Attention over shared K and V needs about the same arithmetic as MHA but reads and writes far less data from memory. MQA and GQA therefore improve both speed, through higher arithmetic intensity, and memory efficiency, through a smaller KV-cache. GQA recovers most of MHA's quality while keeping most of MQA's speed, which makes it a valuable addition to the LLaMA architecture.
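Here is a minimal sketch of the grouped-query idea. It assumes tensors have already been split into heads, and it uses a hypothetical `grouped_query_attention` function rather than the article's code. Each K/V head is shared by a group of consecutive query heads via `repeat_interleave`. Setting `n_kv_heads = n_heads` recovers MHA, and `n_kv_heads = 1` gives MQA.

```python
import torch
import torch.nn.functional as F


def grouped_query_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Causal GQA sketch.

    q:    (batch, n_heads, T, d_head)
    k, v: (batch, n_kv_heads, T, d_head), with n_heads divisible by n_kv_heads
    """
    n_heads, n_kv_heads = q.shape[1], k.shape[1]
    assert n_heads % n_kv_heads == 0
    group_size = n_heads // n_kv_heads
    # Share each K/V head across `group_size` consecutive query heads.
    k = k.repeat_interleave(group_size, dim=1)
    v = v.repeat_interleave(group_size, dim=1)
    T, d_head = q.shape[-2], q.shape[-1]
    scores = q @ k.transpose(-2, -1) / d_head**0.5
    future = torch.ones(T, T).triu(diagonal=1).bool()
    scores = scores.masked_fill(future, float("-inf"))  # causal mask
    return F.softmax(scores, dim=-1) @ v


torch.manual_seed(0)
batch, n_heads, n_kv_heads, T, d_head = 2, 8, 2, 6, 16
q = torch.randn(batch, n_heads, T, d_head)
k = torch.randn(batch, n_kv_heads, T, d_head)  # only 2 K heads are stored/cached
v = torch.randn(batch, n_kv_heads, T, d_head)
out = grouped_query_attention(q, k, v)
print(out.shape)  # torch.Size([2, 8, 6, 16])
```

Here 8 query heads share 2 K/V heads, so this layer's KV-cache is 4 times smaller than with MHA. The output shape is unchanged.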

3. Rotary Positional Embeddings for Enhanced Attention

Within the LLaMA 2 model, Rotary Positional Embeddings (RoPE) inject positional context into the attention computation. Self-attention on its own has no notion of token order, so tokens need positional information for attention scores to be meaningful.

Position embeddings are typically either absolute or relative. Absolute position embeddings encode a token’s absolute position within the input sequence. Relative position embeddings capture the positional relationship between two tokens. Either way, each token learns where it sits in the sequence.

Rotary Positional Embeddings take a different approach: they encode position by rotating the query and key vectors. The goal is for the inner product of q at position m and k at position n to depend only on q, k, and their relative distance m - n. RoPE achieves this by multiplying each vector by a rotation matrix whose angle is proportional to the token's position. Each pair of dimensions (2i, 2i+1) is rotated by the angle $m\theta_i$, with $\theta_i = 10000^{-2i/d}$. Because rotations compose, $(R_m q)^\top (R_n k) = q^\top R_{n-m}\, k$, which depends only on the offset.

RoPE is applied to the queries and keys inside every attention layer rather than added once to the input embeddings, so relative position information is available at every layer.
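The relative-position property can be checked numerically. The sketch below is an illustrative implementation, not the article's code. `rope_angles`, `apply_rope`, and `score` are hypothetical names, the sketch uses the common base of 10000, and it rotates consecutive dimension pairs (2i, 2i+1). Implementations differ in how they pair dimensions, so treat that layout as an assumption.

```python
import torch


def rope_angles(seq_len: int, dim: int, base: float = 10000.0) -> torch.Tensor:
    """Angles m * theta_i with theta_i = base^(-2i/dim); shape (seq_len, dim // 2)."""
    inv_freq = base ** (-torch.arange(0, dim, 2, dtype=torch.float64) / dim)
    positions = torch.arange(seq_len, dtype=torch.float64)
    return torch.outer(positions, inv_freq)


def apply_rope(x: torch.Tensor, angles: torch.Tensor) -> torch.Tensor:
    """Rotate each consecutive pair (x[2i], x[2i+1]) of the last dim by its angle.

    x: (seq_len, dim), angles: (seq_len, dim // 2)
    """
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    cos, sin = angles.cos(), angles.sin()
    rotated = torch.stack((x_even * cos - x_odd * sin, x_even * sin + x_odd * cos), dim=-1)
    return rotated.flatten(-2)  # interleave the rotated pairs back into (seq_len, dim)


torch.manual_seed(0)
dim, seq_len = 8, 32
angles = rope_angles(seq_len, dim)
q = torch.randn(dim, dtype=torch.float64)
k = torch.randn(dim, dtype=torch.float64)


def score(m: int, n: int) -> float:
    """Attention logit between q placed at position m and k placed at position n."""
    q_m = apply_rope(q.unsqueeze(0), angles[m : m + 1])[0]
    k_n = apply_rope(k.unsqueeze(0), angles[n : n + 1])[0]
    return torch.dot(q_m, k_n).item()


# Same offset (m - n = 3) at different absolute positions gives the same score.
print(abs(score(10, 7) - score(25, 22)) < 1e-9)  # True
```

In a full model, a function like `apply_rope` would be applied to every head's queries and keys in every layer, after the linear projections and before the attention scores. Values are not rotated.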

4. SwiGLU Activation Function

In the transformer architecture, the output of the attention layer passes through a position-wise feed-forward network, a small multi-layer perceptron. The ReLU activation function, which simply zeros out negative inputs, has been widely used there, but variants have been proposed to improve model quality and convergence.

In “Llama 2: Open Foundation and Fine-Tuned Chat Models,” the authors report using the SwiGLU variant to train their large language models. In “GLU Variants Improve Transformer,” Shazeer (2020) shows empirically that certain members of a class of activation functions called Gated Linear Units (GLU) outperform plain ReLU. SwiGLU multiplies a Swish-activated projection of the input element-wise with a second linear projection: $\mathrm{FFN}_{\mathrm{SwiGLU}}(x) = W_2\big(\mathrm{Swish}(W_1 x) \odot W_3 x\big)$.
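Here is a minimal PyTorch sketch of a SwiGLU feed-forward block, assuming bias-free linear layers as in LLaMA. Several details are illustrative assumptions rather than the article's code: the class name `SwiGLUFeedForward`, the `w1`/`w2`/`w3` attribute names, and the hidden size of 172 (roughly 2/3 × 4 × 64). `F.silu` is Swish with β = 1.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SwiGLUFeedForward(nn.Module):
    """Gated feed-forward block: w2(silu(w1 x) * w3 x)."""

    def __init__(self, dim: int, hidden_dim: int) -> None:
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # linear ("up") projection
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # output ("down") projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The Swish-activated branch gates the linear branch element-wise.
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


torch.manual_seed(0)
ffn = SwiGLUFeedForward(dim=64, hidden_dim=172)
x = torch.randn(2, 10, 64)  # (batch, seq_len, dim)
print(ffn(x).shape)  # torch.Size([2, 10, 64])
```

The gated block has three weight matrices instead of two. To compensate, the LLaMA paper shrinks the hidden size to about 2/3 of the usual 4 × dim, so the parameter count stays comparable to a standard ReLU feed-forward layer.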

5. RMSNorm Layer

In NLP, it has been customary to use layer normalization to improve training stability. LayerNorm re-centers and re-scales each layer's inputs, which makes the model more robust to shifts in their distribution. However, computing both the mean and the standard deviation adds overhead in large, deep networks.

In “Root Mean Square Layer Normalization,” Zhang and Sennrich argue that the benefit of layer normalization comes from re-scaling invariance rather than re-centering invariance. They propose rescaling the inputs by their root mean square alone: $\mathrm{RMS}(x) = \sqrt{\tfrac{1}{n}\sum_{i=1}^{n} x_i^2}$ and $\bar{y}_i = \frac{x_i}{\mathrm{RMS}(x)}\, g_i$, where $g$ is a learned gain.
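Here is a minimal PyTorch sketch of RMSNorm. It adds a small `eps` inside the square root for numerical stability, an implementation detail that is not part of the RMS definition. The class name and the `eps` default are illustrative. The second check illustrates re-scaling invariance: multiplying the input by 10 leaves the output essentially unchanged.

```python
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """y = x / sqrt(mean(x^2) + eps) * g  (no mean subtraction, no bias)."""

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # learnable gain g

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight


norm = RMSNorm(dim=4)
x = torch.tensor([[1.0, -2.0, 3.0, -4.0]])
y = norm(x)
print(y.pow(2).mean().item())  # ≈ 1.0: unit root mean square (gain g = 1)
print(torch.allclose(norm(10 * x), y, atol=1e-6))  # True: re-scaling invariance
```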

Decoder-Only

The model learns a latent representation of language in a self-supervised manner, with a surprisingly simple approach. Fixed-size sequences are randomly sampled from a large corpus to build a batch of contexts as input to the model. The targets are the same sequences shifted by one token, so the model learns to predict the next token given its context. Training minimizes the cross-entropy loss through gradient descent: $\mathcal{L}(\theta) = -\frac{1}{T}\sum_{t=1}^{T} \log p_\theta(x_{t+1} \mid x_1, \dots, x_t)$.
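The data side of this objective fits in a few lines. The sketch below assumes the corpus is a 1-D tensor of token ids. `get_batch` and `next_token_loss` are hypothetical helpers in the style of Karpathy's nanoGPT, not the article's exact implementation. All-zero logits stand in for an untrained model that predicts a uniform distribution, whose loss is ln(vocab_size).

```python
import torch
import torch.nn.functional as F


def get_batch(data: torch.Tensor, block_size: int, batch_size: int):
    """Randomly sample fixed-size windows; targets are the inputs shifted by one."""
    starts = torch.randint(0, len(data) - block_size, (batch_size,))
    x = torch.stack([data[s : s + block_size] for s in starts])
    y = torch.stack([data[s + 1 : s + block_size + 1] for s in starts])
    return x, y


def next_token_loss(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Mean cross-entropy over all positions: logits (B, T, V), targets (B, T)."""
    B, T, V = logits.shape
    return F.cross_entropy(logits.view(B * T, V), targets.view(B * T))


torch.manual_seed(0)
vocab_size = 50
data = torch.randint(0, vocab_size, (1000,))  # stand-in for a tokenized corpus
x, y = get_batch(data, block_size=8, batch_size=4)
logits = torch.zeros(4, 8, vocab_size)  # untrained "model": uniform predictions
loss = next_token_loss(logits, y)
print(loss.item())  # ln(50) ≈ 3.912
```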

Part 2: Reinforcement Learning from Human Feedback (Human-in-the-Loop)

RLHF, or reinforcement learning from human feedback, is a machine learning technique that optimizes an AI model's behavior by combining human judgments with reinforcement learning. It is widely used to train and fine-tune generative AI models, including large language models (LLMs), for tasks such as realistic content generation, question answering, and code generation. LLMs trained with RLHF tend to produce outputs that are more informative and better aligned with human values.

RLHF is well suited to tasks whose goals are complex, ill-defined, and hard to specify as an explicit objective. It has been widely used to improve the accuracy, relevance, and safety of generative AI models.

An LLM (large language model) typically goes through three training steps before it can generate meaningful answers: pre-training, supervised fine-tuning (SFT), and reinforcement learning from human feedback (RLHF).

  1. Pre-training: This initial step trains a model, such as a Transformer or a Mamba model, to predict the next token. The resulting model can complete sentences, but it does not know when to stop generating text. Mistral-7B-v0.1 is an example of such a pre-trained (base) model. This step typically accounts for over 90% of training.
  2. Supervised Fine-Tuning (SFT): In SFT, the model is trained on high-quality pairs of data, typically request-response pairs. This phase teaches the model how to answer and when to stop. The training objective is still next-token prediction, but the data now reflects real-world usage.
  3. Reinforcement Learning from Human Feedback (RLHF): RLHF refines the LLM so that its answers align with human preferences. Human annotators compare different model-generated answers to the same question. Their feedback is used to train a reward model, which then guides further optimization of the SFT model.

After these steps, the LLM can generate meaningful answers. Models that have been fine-tuned beyond pre-training are usually labeled “Instruct” or “Chat,” as in Mistral-7B-Instruct-v0.1. A recent alternative, Direct Preference Optimization (DPO), removes the need for a separate reward model and optimizes the LLM directly on preference data.
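For contrast with the reward-model route, here is a minimal sketch of the DPO loss. It assumes you already have the summed log-probabilities of the preferred (chosen) and dispreferred (rejected) responses, under both the policy and a frozen reference model. `dpo_loss` and the `beta = 0.1` default are illustrative choices, not part of this article's code.

```python
import torch
import torch.nn.functional as F


def dpo_loss(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    ref_chosen_logps: torch.Tensor,
    ref_rejected_logps: torch.Tensor,
    beta: float = 0.1,
) -> torch.Tensor:
    """-log sigmoid(beta * [(log pi - log ref)(chosen) - (log pi - log ref)(rejected)])."""
    chosen_logratio = policy_chosen_logps - ref_chosen_logps
    rejected_logratio = policy_rejected_logps - ref_rejected_logps
    return -F.logsigmoid(beta * (chosen_logratio - rejected_logratio)).mean()


# Summed log-probabilities of each full response for a batch of 2 preference pairs.
policy_chosen = torch.tensor([-12.0, -20.0])
policy_rejected = torch.tensor([-15.0, -18.0])
# At initialization the policy equals the frozen reference, so all log-ratios are zero.
loss = dpo_loss(policy_chosen, policy_rejected, policy_chosen.clone(), policy_rejected.clone())
print(round(loss.item(), 4))  # 0.6931, i.e. ln 2
```

Raising the policy's log-probability of the chosen responses relative to the reference lowers the loss. This is how DPO pushes the model toward preferred answers without a separately trained reward model.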

At each stage, the LLM is a model trained for that stage's purpose, and a pre-trained model can be further fine-tuned for specific tasks. Together, SFT and RLHF significantly improve model performance, which makes the model suitable for chat applications and API integration. A model that has completed these steps is also the usual starting point for systems such as RAG (Retrieval-Augmented Generation).

Reinforcement learning from human feedback (RLHF), which has drawn a lot of attention recently, has driven a new wave of RL applications in NLP, especially for large language models (LLMs). In this blog, we walk through the RLHF training pipeline for an LLM using Hugging Face libraries.

The RLHF pipeline consists of 3 phases:

  • Domain-specific pre-training: Fine-tune a pre-trained LLM on raw text with a causal language modeling objective.
  • Supervised fine-tuning: Fine-tune the domain-specific LLM on task-specific as well as domain-specific (prompt/instruction, response) pairs.
  • RLHF
    Reward model training: Train a language model to score responses as good or bad (thumbs up, thumbs down), using (prompt, good_response, bad_response) data labeled by human experts.
    RLHF fine-tuning: Use the trained reward model to optimize the LLM so that its responses align with human preferences.

1. Domain Specific Pre-training

Domain-specific pre-training gives your language model knowledge of the domain where it will ultimately be applied. In this step, the model is fine-tuned with causal language modeling (next-token prediction). This is very similar to training a model from scratch on raw text, except that the corpus here is domain-specific. Far less data is needed, because the model has already been pre-trained on trillions of tokens.
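As a minimal sketch of this step's data preparation, assume the documents are already tokenized into lists of ids. The hypothetical `group_texts` helper below packs them into fixed-size blocks for causal language modeling. It is similar in spirit to the helper of the same name in Hugging Face's causal language modeling examples. The resulting blocks could then be fed to a causal LM, for example through `transformers`' `Trainer`.

```python
from itertools import chain


def group_texts(tokenized_docs: list[list[int]], block_size: int) -> dict[str, list[list[int]]]:
    """Concatenate tokenized documents and cut them into equal blocks for causal LM training."""
    concatenated = list(chain.from_iterable(tokenized_docs))
    total = (len(concatenated) // block_size) * block_size  # drop the ragged tail
    blocks = [concatenated[i : i + block_size] for i in range(0, total, block_size)]
    # For Hugging Face causal-LM models, labels can simply be a copy of input_ids:
    # the model shifts them by one position internally when computing the loss.
    return {"input_ids": blocks, "labels": [list(b) for b in blocks]}


docs = [[5, 6, 7, 2], [8, 9, 2], [10, 11, 12, 13, 2]]  # toy token ids, 2 = EOS
batch = group_texts(docs, block_size=4)
print(batch["input_ids"])  # [[5, 6, 7, 2], [8, 9, 2, 10], [11, 12, 13, 2]]
```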


2. Supervised fine-tuning

The output of the domain-specific pre-training step is a model that recognizes the context of input text and predicts the next words or sentences, much like a text-completion model. However, it is not designed to respond to prompts. Supervised fine-tuning on prompt-response pairs is a cost-effective way to inject domain-specific and task-specific knowledge into a pre-trained LLM so that it can answer context-specific questions. This step is also referred to as instruction fine-tuning.

The result of this step is an LLM that behaves like a chat agent.
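A key detail of supervised fine-tuning is that the loss is usually computed only on the response tokens. The sketch below assumes the prompt and response are already tokenized; `build_sft_example`, `sft_loss`, and the token ids are hypothetical. Prompt positions are labeled -100, the default `ignore_index` of `F.cross_entropy`. An EOS token is appended so the model learns when to stop.

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # default ignore_index of F.cross_entropy


def build_sft_example(prompt_ids: list[int], response_ids: list[int], eos_id: int):
    """Concatenate prompt + response + EOS; supervise only the response and EOS."""
    input_ids = prompt_ids + response_ids + [eos_id]
    labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids + [eos_id]
    return torch.tensor(input_ids), torch.tensor(labels)


def sft_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Next-token loss: position t predicts labels[t + 1]; prompt targets are ignored."""
    return F.cross_entropy(logits[:-1], labels[1:], ignore_index=IGNORE_INDEX)


torch.manual_seed(0)
input_ids, labels = build_sft_example([11, 12, 13], [21, 22], eos_id=2)
print(labels.tolist())  # [-100, -100, -100, 21, 22, 2]
vocab_size = 32
logits = torch.randn(len(input_ids), vocab_size)  # stand-in for model outputs
loss = sft_loss(logits, labels)
```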

3. Reward model training

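The RLHF training strategy aligns the LLM with human preferences so that it produces better outputs. For this purpose, a reward model is trained to output a score for a (prompt, response) pair. This can be framed as a simple classification task: given two responses to the same prompt, which one did the human annotators prefer? The reward model is trained on such preference data labeled by expert human annotators.

The code below does not train a learned reward model. Instead, it defines the two scoring components used by the PPO trainer in the next step. ValueModel is a critic that adds a linear value head to a transformer backbone. GoldenRewardModel is a rule-based reward for a toy copy task that stands in for a reward model trained on human preferences.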

```python
from __future__ import annotations

import torch
import torch.nn as nn
from transformers import PreTrainedTokenizer


class ValueModel(nn.Module):
    """Critic for PPO: a transformer backbone plus a linear head giving one value per token."""

    def __init__(self, transformer: nn.Module, device=None, dtype=None) -> None:
        super().__init__()
        self.transformer = transformer
        self.v_head = nn.Linear(transformer.config.hidden_size, 1, bias=False, device=device, dtype=dtype)

    def forward(self, *args, **kwargs) -> torch.Tensor:
        # last_hidden_state: (batch, seq_len, hidden_size) -> values: (batch, seq_len)
        hidden_states = self.transformer(*args, **kwargs, use_cache=False).last_hidden_state
        values = self.v_head(hidden_states).squeeze(-1)
        return values


class GoldenRewardModel:
    """Rule-based reward for a copy task: +1 for each response token that reproduces the prompt.

    Assumes every prompt is left-padded to the same length and ends with EOS, so the
    first EOS in row 0 marks where the response starts for the whole batch.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer) -> None:
        self.tokenizer = tokenizer

    def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> torch.Tensor:
        batch_size, seq_len = input_ids.shape
        input_ids_list = input_ids.tolist()
        prompt_length = input_ids_list[0].index(self.tokenizer.eos_token_id) + 1

        scores = [[0 for _ in range(seq_len)] for _ in range(batch_size)]
        for input_id, score in zip(input_ids_list, scores):
            prompt_id = input_id[:prompt_length]
            # The target is the prompt itself (EOS included) with padding removed.
            target_id = [x for x in prompt_id if x != self.tokenizer.pad_token_id]
            response_id = input_id[prompt_length:]
            # Reward consecutive correct copies; stop at the first mismatch.
            for j, (rsp_id, tgt_id) in enumerate(zip(response_id, target_id)):
                if rsp_id != tgt_id:
                    break
                score[prompt_length + j] = 1

        return torch.tensor(scores, dtype=torch.float32, device=input_ids.device)
```
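The `GoldenRewardModel` above is rule-based, so nothing is trained. When a reward model is instead learned from human preference data, a common choice is a pairwise ranking loss that pushes the chosen response's score above the rejected one's. InstructGPT and TRL's `RewardTrainer` both use this loss. Here is a minimal sketch; `pairwise_reward_loss` is a hypothetical name, and the scores are made-up numbers.

```python
import torch
import torch.nn.functional as F


def pairwise_reward_loss(chosen_rewards: torch.Tensor, rejected_rewards: torch.Tensor) -> torch.Tensor:
    """Bradley-Terry style ranking loss: -log sigmoid(r_chosen - r_rejected)."""
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()


# Scalar scores for (prompt, chosen) and (prompt, rejected) pairs, one per comparison.
chosen = torch.tensor([2.0, 0.5, 1.0])
rejected = torch.tensor([-1.0, 0.5, 3.0])
loss = pairwise_reward_loss(chosen, rejected)
print(round(loss.item(), 4))  # 0.9562
```

In practice, the scalar scores would come from a transformer with a scalar head, similar to `ValueModel`, read at the last non-padding token of each (prompt, response) sequence.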

4. RLHF fine-tuning (for alignment)

Finally, in this step we train the SFT model from step 2 to generate outputs that maximize the reward model's scores. In other words, the reward model steers the supervised model toward responses that humans prefer. Research such as OpenAI's InstructGPT work has found that, given high-quality preference data, models that undergo RLHF are preferred over SFT-only models. This training uses a reinforcement learning method called Proximal Policy Optimization (PPO).
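Standard RLHF setups, such as InstructGPT and the PPO trainer in TRL, usually add a per-token penalty for drifting away from the frozen SFT model. This keeps the policy from exploiting the reward model with degenerate text. The `PPOTrainer` in this article passes the reward model's scores to the advantage computation without such a penalty. Below is a minimal sketch of the shaping step. It assumes one reward-model score per response, right-padded response masks, and a hypothetical `kl_penalized_rewards` helper.

```python
import torch


def kl_penalized_rewards(
    score: torch.Tensor,
    logprobs: torch.Tensor,
    ref_logprobs: torch.Tensor,
    mask: torch.Tensor,
    kl_coef: float = 0.1,
) -> torch.Tensor:
    """Per-token rewards: -kl_coef * (log pi - log pi_ref), plus the score on the last token.

    score: (batch,); logprobs, ref_logprobs, mask: (batch, response_len)
    """
    kl = logprobs - ref_logprobs  # per-token KL estimate
    rewards = -kl_coef * kl * mask
    last = mask.sum(dim=1).long() - 1  # index of the last real response token
    rewards[torch.arange(rewards.shape[0]), last] += score
    return rewards


score = torch.tensor([1.0, -0.5])
logprobs = torch.tensor([[-1.0, -2.0, -0.5], [-0.3, -0.4, 0.0]])
ref_logprobs = torch.tensor([[-1.2, -2.0, -1.0], [-0.3, -0.9, 0.0]])
mask = torch.tensor([[1, 1, 1], [1, 1, 0]])  # second response has one padding token
rewards = kl_penalized_rewards(score, logprobs, ref_logprobs, mask)
```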

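Proximal Policy Optimization is a reinforcement learning algorithm introduced by OpenAI in 2017. It first became popular as one of the top-performing deep reinforcement learning algorithms for control problems such as video games and simulated 3D locomotion. It has since found a place in NLP, specifically in the RLHF pipeline. For a more detailed overview of PPO, see the Hugging Face TRL PPO documentation (reference 4).

Below, we code a simplified PPO trainer loosely modeled on Hugging Face's PPOTrainer class. For each batch of prompts, it samples responses from the policy and scores them with the reward model. It then estimates per-token advantages with Generalized Advantage Estimation (GAE) and minimizes PPO's clipped policy loss plus a clipped value loss.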

```python
from __future__ import annotations

from typing import Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PreTrainedTokenizer


# The original snippet calls these helpers without defining them. The definitions
# below are modeled on same-named utilities from older TRL releases (trl.core).
def masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    return (values * mask).sum() / mask.sum()


def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) -> torch.Tensor:
    mean = masked_mean(values, mask)
    variance = masked_mean((values - mean) ** 2, mask)
    if unbiased:  # Bessel's correction over the unmasked elements
        n = mask.sum()
        variance = variance * n / (n - 1)
    return variance


def masked_whiten(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Normalize to zero mean and unit variance over the unmasked positions.
    mean, var = masked_mean(values, mask), masked_var(values, mask)
    return (values - mean) * torch.rsqrt(var + 1e-8)


def logprobs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # (batch, seq, vocab) logits -> (batch, seq) log-probabilities of the label tokens
    logp = F.log_softmax(logits, dim=-1)
    return torch.gather(logp, 2, labels.unsqueeze(2)).squeeze(-1)


def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
    # Per-position entropy: logsumexp(z) - sum_i softmax(z)_i * z_i
    pd = F.softmax(logits, dim=-1)
    return torch.logsumexp(logits, dim=-1) - (pd * logits).sum(dim=-1)


class PPOTrainer:
    """Simplified PPO trainer for RLHF.

    `args` must provide: gamma, lam, max_length, temp, top_p, cliprange,
    cliprange_value, vf_coef and max_grad_norm.
    """

    def __init__(
        self,
        args,
        tokenizer: PreTrainedTokenizer,
        policy_model: PreTrainedModel,
        value_model: nn.Module,
        reward_model: Callable,
        optimizer: torch.optim.Optimizer,
    ) -> None:
        self.config = args
        self.tokenizer = tokenizer
        self.policy_model = policy_model
        self.value_model = value_model
        self.reward_model = reward_model
        self.optimizer = optimizer  # expected to hold both policy and value parameters

    def compute_advantages(
        self,
        values: torch.Tensor,
        rewards: torch.Tensor,
        mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Generalized Advantage Estimation over the response tokens."""
        lastgaelam = 0
        advantages_reversed = []
        gen_len = rewards.shape[-1]

        values = values * mask
        rewards = rewards * mask

        # Walk backwards: A_t = delta_t + gamma * lam * A_{t+1}
        for t in reversed(range(gen_len)):
            nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
            delta = rewards[:, t] + self.config.gamma * nextvalues - values[:, t]
            lastgaelam = delta + self.config.gamma * self.config.lam * lastgaelam
            advantages_reversed.append(lastgaelam)
        advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)

        returns = advantages + values  # targets for the value head
        advantages = masked_whiten(advantages, mask)
        advantages = advantages.detach()
        return values, advantages, returns

    @torch.no_grad()
    def sample_experience(self, prompt_ids: torch.Tensor, prompt_mask: torch.Tensor) -> dict[str, torch.Tensor]:
        self.policy_model.eval()
        self.value_model.eval()

        _, prompt_length = prompt_ids.shape

        # 1) Roll out responses from the current policy.
        input_ids = self.policy_model.generate(
            input_ids=prompt_ids,
            attention_mask=prompt_mask,
            max_length=self.config.max_length,
            use_cache=True,
            do_sample=self.config.temp > 0,
            temperature=self.config.temp,
            top_p=self.config.top_p,
            eos_token_id=[self.tokenizer.eos_token_id, self.tokenizer.pad_token_id],
            pad_token_id=self.tokenizer.pad_token_id,
        )

        # 2) Score the full sequences with the policy, reward model, and critic.
        attention_mask = (input_ids != self.tokenizer.pad_token_id).long()
        old_logits = self.policy_model(input_ids, attention_mask=attention_mask).logits
        rewards = self.reward_model(input_ids=input_ids, attention_mask=attention_mask)
        values = self.value_model(input_ids=input_ids, attention_mask=attention_mask)

        old_logprobs = logprobs_from_logits(old_logits[:, :-1, :], input_ids[:, 1:])

        # 3) Keep only response positions; logits/values at t refer to token t + 1.
        mask = attention_mask[:, prompt_length:]
        rewards = rewards[:, prompt_length:]
        values = values[:, prompt_length - 1 : -1]
        old_logprobs = old_logprobs[:, prompt_length - 1 :]

        values, advantages, returns = self.compute_advantages(values, rewards, mask)

        return dict(
            prompt_ids=prompt_ids,
            old_logprobs=old_logprobs,
            values=values,
            rewards=rewards,
            input_ids=input_ids,
            attention_mask=attention_mask,
            advantages=advantages,
            returns=returns,
        )

    def train_minibatch(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        self.policy_model.train()
        self.value_model.train()

        prompt_ids = inputs["prompt_ids"]
        old_logprobs = inputs["old_logprobs"]
        rewards = inputs["rewards"]
        values = inputs["values"]
        attention_mask = inputs["attention_mask"]
        input_ids = inputs["input_ids"]
        advantages = inputs["advantages"]
        returns = inputs["returns"]

        _, prompt_length = prompt_ids.shape
        mask = attention_mask[:, prompt_length:]

        logits = self.policy_model(input_ids=input_ids, attention_mask=attention_mask, use_cache=False).logits
        logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
        logprobs = logprobs[:, prompt_length - 1 :]

        # Diagnostics only: these KL terms are logged, not added to the loss.
        entropy = masked_mean(entropy_from_logits(logits[:, prompt_length - 1 : -1]), mask)
        approxkl = 0.5 * masked_mean((logprobs - old_logprobs) ** 2, mask)
        policykl = masked_mean(old_logprobs - logprobs, mask)

        # Clipped surrogate policy loss.
        ratio = torch.exp(logprobs - old_logprobs)
        pg_losses1 = -advantages * ratio
        pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - self.config.cliprange, 1.0 + self.config.cliprange)
        pg_loss = masked_mean(torch.max(pg_losses1, pg_losses2), mask)
        pg_clipfrac = masked_mean((pg_losses2 > pg_losses1).float(), mask)

        # Clipped value loss.
        vpreds = self.value_model(input_ids=input_ids, attention_mask=attention_mask)
        vpreds = vpreds[:, prompt_length - 1 : -1]

        vpredclipped = torch.clamp(
            vpreds, min=values - self.config.cliprange_value, max=values + self.config.cliprange_value
        )
        vf_losses1 = (vpreds - returns) ** 2
        vf_losses2 = (vpredclipped - returns) ** 2
        vf_loss = 0.5 * masked_mean(torch.max(vf_losses1, vf_losses2), mask)
        vf_clipfrac = masked_mean((vf_losses2 > vf_losses1).float(), mask)

        loss = pg_loss + self.config.vf_coef * vf_loss

        loss.backward()
        pg_grad_norm = nn.utils.clip_grad_norm_(self.policy_model.parameters(), max_norm=self.config.max_grad_norm)
        vf_grad_norm = nn.utils.clip_grad_norm_(
            self.value_model.parameters(), max_norm=self.config.max_grad_norm * self.config.vf_coef
        )
        self.optimizer.step()
        self.optimizer.zero_grad()

        stats = {
            **{f"lr/group_{i}": pg["lr"] for i, pg in enumerate(self.optimizer.param_groups)},
            "loss/policy": pg_loss.item(),
            "loss/value": vf_loss.item(),
            "loss/total": loss.item(),
            "policy/grad_norm": pg_grad_norm.item(),
            "policy/entropy": entropy.item(),
            "policy/approxkl": approxkl.item(),
            "policy/policykl": policykl.item(),
            "policy/clipfrac": pg_clipfrac.item(),
            "policy/advantages_mean": masked_mean(advantages, mask).item(),
            "policy/advantages_var": masked_var(advantages, mask).item(),
            "policy/ratio_mean": masked_mean(ratio, mask).item(),
            "returns/mean": masked_mean(returns, mask).item(),
            "returns/var": masked_var(returns, mask).item(),
            "val/grad_norm": vf_grad_norm.item(),
            "val/vpred": masked_mean(vpreds, mask).item(),
            "val/error": masked_mean(vf_losses1, mask).item(),
            "val/clipfrac": vf_clipfrac.item(),
            "val/mean": masked_mean(values, mask).item(),
            "val/var": masked_var(values, mask).item(),
            "env/reward_mean": masked_mean(rewards, mask).item(),
            "env/reward_var": masked_var(rewards, mask).item(),
            "env/reward_total": rewards.sum(1).mean().item(),
        }
        return stats
```
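To see what `compute_advantages` computes, here is the same Generalized Advantage Estimation recursion for a single trajectory, without masking or whitening. `gae` is a hypothetical standalone helper, and the rewards and values are toy numbers. With `gamma = lam = 1`, each advantage reduces to the remaining total reward minus the value estimate, so the returns equal the total future reward.

```python
import torch


def gae(rewards: torch.Tensor, values: torch.Tensor, gamma: float, lam: float):
    """GAE for one trajectory (1-D tensors), bootstrapping with 0 after the last step."""
    advantages = torch.zeros_like(rewards)
    last_gae = 0.0
    for t in reversed(range(len(rewards))):
        next_value = values[t + 1] if t + 1 < len(rewards) else 0.0
        delta = rewards[t] + gamma * next_value - values[t]  # TD residual
        last_gae = delta + gamma * lam * last_gae
        advantages[t] = last_gae
    returns = advantages + values  # value-function targets
    return advantages, returns


rewards = torch.tensor([0.0, 0.0, 1.0])  # reward only at the final token
values = torch.tensor([0.5, 0.5, 0.5])
adv, ret = gae(rewards, values, gamma=1.0, lam=1.0)
print(ret.tolist())  # [1.0, 1.0, 1.0]
```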



Final Thoughts

In this article, we introduced the building blocks of the TinyLlama model and the pipeline that many researchers and engineers use to create domain-specific LLMs aligned with human preferences. Keep in mind that RLHF requires a high-quality curated dataset labeled by human experts who have graded previous LLM responses (human-in-the-loop), so the process is costly and slow. Besides RLHF, newer techniques exist, such as DPO (Direct Preference Optimization) and RLAIF (Reinforcement Learning from AI Feedback). These methods are generally cheaper and faster than RLHF, but many of the underlying principles stay the same.

Conclusion

In summary, this article detailed how to build the TinyLlama model and refine it with Reinforcement Learning from Human Feedback (RLHF). By combining a Transformer decoder with tuned configurations, a ByteTokenizer, and a PromptDataset, we engineered a compact yet capable language model. For the RLHF stage, the PPOTrainer class uses Proximal Policy Optimization (PPO) to optimize the model against the rule-based GoldenRewardModel, which stands in for a reward model learned from human feedback. The full code is available on GitHub.

References

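  1. Chip Huyen, “RLHF: Reinforcement Learning from Human Feedback” (2023): https://huyenchip.com/2023/05/02/rlhf.html#why_sft
  2. Hugging Face TRL documentation, SFTTrainer: https://huggingface.co/docs/trl/sft_trainer
  3. Hugging Face TRL documentation, RewardTrainer: https://huggingface.co/docs/trl/reward_trainer
  4. Hugging Face TRL documentation, PPOTrainer: https://huggingface.co/docs/trl/ppo_trainer
  5. LLaMA: Open and Efficient Foundation Language Models
  6. Llama 2: Open Foundation and Fine-Tuned Chat Models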
  7. Attention Is All You Need
  8. RoFormer: Enhanced Transformer with Rotary Position Embedding
  9. GLU Variants Improve Transformer
  10. Root Mean Square Layer Normalization
  11. SentencePiece


YOUNESS-ELBRAG

Machine Learning Engineer || AI Architect @AIGOT. I explore advanced topics in AI, especially geometric deep learning.