从头开始使用PyTorch构建自己的Llama 3架构
ztj100 2024-11-14 19:24 18 浏览 0 评论
从头开始使用PyTorch构建自己的Llama 3架构
构建Llama 3模型完整架构的逐步指南,从零开始,并在自定义数据集上进行训练和推断。
[图片来自作者]: Llama 3架构显示训练和推断流程。我想象了这个图,因为官方的Llama 3论文中没有这个图。到本文结尾时,我相信你应该能够绘制出比这个更好的架构。
通过本文你将实现什么?
- 您将深入了解 Llama 3 模型的每个组件在背后的工作原理。
- 您将编写代码来构建Llama 3的每个组件,然后将它们全部组合在一起以构建一个功能完全的Llama 3模型。
- 您还将编写代码以使用新的自定义数据集训练您的模型。
- 您还将编写代码来执行推理,以便您的 Llama 3 模型能够根据输入提示生成新文本。
先决条件
现在我们知道我们想要实现什么,让我们一步一步开始构建一切。
步骤 1:输入块
如上面的Llama 3架构图所示,输入块有三个组成部分:- 文本/提示、分词器和嵌入。
输入块内部的组件是如何工作的? 有一句流行的话说“一个图片胜过千言万语”,让我们看看下面的流程图,以了解输入块内部的工作流程。
[作者图片]: 输入块流程图显示了提示、分词器和嵌入流程。
让我们编写输入块代码:
# 导入必要的库
import torch
from torch import nn
from torch.nn import functional as F
import math
import numpy as np
import time
from dataclasses import dataclass
from typing import Optional, Tuple, List
import pandas as pd
from matplotlib import pyplot as plt
### 步骤 1: 输入块 ###
# 使用Tiny Shakespeare数据集进行字符级标记化。以下字符级标记器的部分内容参考了Andrej karpathy的GitHub (https://github.com/karpathy/nanoGPT/blob/master/data/shakespeare_char/prepare.py),我发现其解释得非常好。
# 加载tiny_shakespeare数据文件 (https://github.com/tamangmilan/llama3/blob/main/tiny_shakespeare.txt)
device: str = 'cuda' if torch.cuda.is_available() else 'cpu' # 根据可用性将设备分配为cuda或cpu
# 加载tiny_shakespeare数据文件。
with open('tiny_shakespeare.txt', 'r') as f:
data = f.read()
# 通过提取tiny_shakespeare数据中的所有唯一字符来准备词汇表
vocab = sorted(list(set(data)))
# 训练Llama 3模型需要额外的标记,例如<|begin_of_text|>、<|end_of_text|>和<|pad_id|>,我们将把它们添加到词汇表中
vocab.extend(['<|begin_of_text|>','<|end_of_text|>','<|pad_id|>'])
vocab_size = len(vocab)
# 创建字符与词汇表中对应整数索引之间的映射。
# 这对于构建标记器的编码和解码函数非常重要。
itos = {i:ch for i, ch in enumerate(vocab)}
stoi = {ch:i for i, ch in enumerate(vocab)}
# 标记器编码函数:接受一个字符串,输出一个整数列表
def encode(s):
return [stoi[ch] for ch in s]
# 标记器解码函数:接受一个整数列表,输出一个字符串
def decode(l):
return ''.join(itos[i] for i in l)
# 定义张量标记变量,以便在模型训练期间使用
token_bos = torch.tensor([stoi['<|begin_of_text|>']], dtype=torch.int, device=device)
token_eos = torch.tensor([stoi['<|end_of_text|>']], dtype=torch.int, device=device)
token_pad = torch.tensor([stoi['<|pad_id|>']], dtype=torch.int, device=device)
prompts = "Hello World"
encoded_tokens = encode(prompts)
decoded_text = decode(encoded_tokens)
### 测试: 输入块代码 ###
# 您需要移除下面的三重引号以进行测试
"""
print(f"莎士比亚的角色长度: {len(data)}")
print(f"词汇表看起来像这样: {''.join(vocab)}\n")
print(f"词汇表大小: {vocab_size}")
print(f"编码标记: {encoded_tokens}")
print(f"解码文本: {decoded_text}")
"""
### 测试结果: ###
"""
莎士比亚的角色长度: 1115394
词汇表看起来像这样:
!',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz<|begin_of_text|><|end_of_text|><|pad_id|>
词汇表大小: 68
编码标记: [20, 43, 50, 50, 53, 1, 35, 53, 56, 50, 42]
解码文本: Hello World
"""
步骤 2:解码器块
如果你查看上面的架构图,解码器块由以下子组件组成。
让我们逐个深入探讨这些子组件。
2a. RMS 范数(均方根归一化):
为什么需要RMSNorm? 在上面的架构图中,您一定注意到输入块的输出,即嵌入向量,经过RMSNorm块。这是因为嵌入向量有很多维度(在Llama3-8b中为4096维),并且总是有可能存在不同范围的值。这可能导致模型梯度爆炸或消失,从而导致收敛缓慢甚至发散。RMSNorm将这些值带入一定范围,有助于稳定和加速训练过程。这使得梯度具有更一致的幅度,从而使模型更快收敛。
RMSNorm是如何工作的? 让我们先看一下以下图表。
[图片由作者提供]: RMSNorm在形状为[3,3]的输入嵌入上的实现
例子:让我们将 RMSNorm 应用于第一个 token X1 的嵌入:
为什么选择RMSNorm而不是层归一化? 正如您在上面的例子中注意到的,我们没有计算任何均值或方差,而这是层归一化中所做的。因此,我们可以说RMSNorm通过避免均值和方差的计算来减少计算开销。此外,根据作者的论文,RMSNorm在不损失准确性的情况下提供了性能优势。
让我们实现 RMSNorm:
# 第2步:解码器块
# 注意:由于 Llama 3 模型是由 Meta 开发的,因此为了与他们的代码库保持同步并为将来的兼容性,我将使用 Meta GitHub 上的大部分代码,并进行一些必要的更改以实现我们的目标。
# 定义参数数据类:我们将在模型构建、训练和推理过程中使用这些参数。
# 注意:由于我们希望更快地看到训练和推理的结果,而不是专注于高准确性,因此我们对大多数设置在 Llama 3 模型中的参数选择较低的值。
@dataclass
class ModelArgs:
dim: int = 512 # 嵌入维度
n_layers: int = 8 # 模型解码器块的数量
n_heads: int = 8 # 查询嵌入的头数
n_kv_heads: int = 4 # 键和值嵌入的头数
vocab_size: int = len(vocab) # 词汇表长度
multiple_of: int = 256 # 计算前馈网络维度所需
ffn_dim_multiplier: Optional[float] = None # 计算前馈网络维度所需
norm_eps: float = 1e-5 # 为 RMSNorm 计算设置的默认 Epsilon 值
rope_theta: float = 10000.0 # RePE 计算的默认 theta 值
max_batch_size: int = 10 # 最大批处理大小
max_seq_len: int = 256 # 最大序列长度
epochs: int = 2500 # 训练迭代的总次数
log_interval: int = 10 # 打印日志和损失值的间隔数量
device: str = 'cuda' if torch.cuda.is_available() else 'cpu' # 根据可用性将设备分配为 cuda 或 cpu
## Step2a: RMSNorm
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
device = ModelArgs.device
self.eps = eps
# 缩放参数 gamma,初始化为 1,参数数量等于 dim 的大小
self.weight = nn.Parameter(torch.ones(dim).to(device))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps).to(device)
def forward(self, x):
# 形状: x[bs,seq,dim]
output = self._norm(x.float()).type_as(x)
# 形状: x[bs,seq,dim] -> x_norm[bs,seq,dim]
return output * self.weight
### 测试: RMSNorm 代码 ###
# 你需要去掉下面的三重引号来进行测试
"""
x = torch.randn((ModelArgs.max_batch_size, ModelArgs.max_seq_len, ModelArgs.dim), device=device)
rms_norm = RMSNorm(dim=ModelArgs.dim)
x_norm = rms_norm(x)
print(f"Shape of x: {x.shape}")
print(f"Shape of x_norm: {x_norm.shape}")
"""
### 测试结果: ###
"""
Shape of x: torch.Size([10, 256, 512])
Shape of x_norm: torch.Size([10, 256, 512])
"""
2b. 旋转位置编码 (RoPE):
我们为什么需要旋转位置编码(RoPE)? 在我们讨论为什么之前,让我们回顾一下迄今为止所做的工作。首先,我们将输入文本转换为嵌入。接下来,我们对嵌入应用了RMSNorm。此时,您一定注意到有些不对劲。假设输入文本是“我爱苹果”或“苹果爱我”,模型仍然会将这两个句子视为相同并学习为相同。因为在嵌入中没有定义顺序供模型学习。因此,顺序对于任何语言模型来说都是非常重要的。在Llama 3模型架构中,RePE用于定义句子中每个标记的位置,这不仅保持了顺序,还保持了句子中标记的相对位置。
那么,什么是旋转位置编码,如何工作? 正如上面“为什么”部分所提到的,RoPE是一种位置编码,它通过添加绝对位置信息来编码嵌入,保持句子中标记的顺序,同时结合标记之间的相对位置信息。它通过将给定的嵌入旋转一个叫做旋转矩阵的特殊矩阵来执行编码操作。这种简单但非常强大的数学推导使用旋转矩阵是RoPE的核心。
[作者图片]: 应用于二维向量的旋转矩阵
上面图中的旋转矩阵旋转一个二维向量。然而,Llama 3模型的维度数量是4096,这要多得多。让我们看看如何在高维嵌入上应用旋转。
[Image by writer]: RoPE 实现到嵌入的示例
我们现在知道,嵌入的旋转涉及将每个嵌入位置 (m) 值与每对嵌入维度的 theta (θ) 相乘。这就是 RoPE 如何通过旋转矩阵的实现捕获绝对位置以及相对位置信息的方式。
注意:旋转矩阵需要转换为极坐标形式,嵌入向量需要转换为复数,然后才能执行旋转。旋转完成后,旋转后的嵌入需要转换回实数,以便进行注意力操作。此外,RoPE 仅应用于查询和键嵌入。它不适用于值嵌入。
让我们深入了解 RoPE 编码:
## Step2b: RoPE
def precompute_freqs_cis(dim:int, seq_len: int, theta: float=10000.0):
# 计算每个维度对的Theta值,取dim/2
device = ModelArgs.device
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2,device=device)[:(dim//2)].float()/dim))
# 计算序列中的位置范围(m)
t = torch.arange(seq_len, dtype=torch.float32, device=device)
# freqs给出了序列中所有标记位置的Theta值范围
freqs = torch.outer(t, freqs).to(device)
# 这是旋转矩阵,需转换为极坐标形式,以执行嵌入的旋转
freqs_cis = torch.polar(torch.ones_like(freqs).to(device), freqs).to(device)
return freqs_cis
def reshape_for_broadcast(freqs_cis, x):
ndim = x.ndim
assert 0<=1<ndim
assert freqs_cis.shape == (x.shape[1],x.shape[-1]), "freqs_cis和x的最后两个维度必须匹配"
shape = [d if i==1 or i==ndim-1 else 1 for i,d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor)->Tuple[torch.Tensor, torch.Tensor]:
device = ModelArgs.device
# 将旋转位置编码同时应用于查询和键嵌入
# 首先:xq和xk嵌入的最后一个维度需要调整形状,以形成一对。因为旋转矩阵是应用于每对维度的。
# 接下来:将xq和xk转换为复数,因为旋转矩阵仅适用于复数
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)).to(device) #xq_:[bsz, seq_len, n_heads, head_dim/2]
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)).to(device) #xk_:[bsz, seq_len, n_heads, head_dim/2]
# 旋转矩阵(freqs_cis)在序列长度(dim=1)和头维度(dim=3)的维度应与嵌入匹配
# 此外,freqs_cis的形状应与xq和xk相同,因此将freqs_cis的形状更改为:[seq_len,head_dim] -> freqs_cis:[1,seq_len,1,head_dim]
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
#最后,通过与freqs_cis相乘执行旋转操作。
#旋转完成后,将xq_out和xk_out转换回实数并返回
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).to(device) #xq_out:[bsz, seq_len, n_heads, head_dim]
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).to(device) #xk_out:[bsz, seq_len, n_heads, head_dim]
return xq_out.type_as(xq), xk_out.type_as(xk)
### 测试:RoPE代码 ###
# 注意:x_norm是在RMSNorm期间计算的,并在此处用于测试。
# 您需要去掉下面的三重引号以执行测试
"""
head_dim = ModelArgs.dim//ModelArgs.n_heads
wq = nn.Linear(ModelArgs.dim, ModelArgs.n_heads * head_dim, bias=False, device=device)
wk = nn.Linear(ModelArgs.dim, ModelArgs.n_kv_heads * head_dim, bias=False, device=device)
xq = wq(x_norm)
xk = wk(x_norm)
print(f"xq.shape: {xq.shape}")
print(f"xk.shape: {xk.shape}")
xq = xq.view(xq.shape[0],xq.shape[1],ModelArgs.n_heads, head_dim)
xk = xk.view(xk.shape[0],xk.shape[1],ModelArgs.n_kv_heads, head_dim)
print(f"xq.re-shape: {xq.shape}")
print(f"xk.re-shape: {xk.shape}")
freqs_cis = precompute_freqs_cis(dim=head_dim, seq_len=ModelArgs.max_seq_len)
print(f"freqs_cis.shape: {freqs_cis.shape}")
xq_rotate, xk_rotate = apply_rotary_emb(xq, xk, freqs_cis)
print(f"xq_rotate.shape: {xq_rotate.shape}")
print(f"xk_rotate.shape: {xk_rotate.shape}")
"""
### 测试结果: ###
"""
xq.shape: torch.Size([10, 256, 512])
xk.shape: torch.Size([10, 256, 256])
xq.re-shape: torch.Size([10, 256, 8, 64])
xk.re-shape: torch.Size([10, 256, 4, 64])
freqs_cis.shape: torch.Size([256, 32])
xq_rotate.shape: torch.Size([10, 256, 8, 64])
xk_rotate.shape: torch.Size([10, 256, 4, 64])
"""
2c. KV 缓存(仅在推理时需要):
什么是KV-Cache? 在Llama 3架构中,在推理时,引入了KV-Cache的概念,用于以键和值缓存的形式存储先前生成的令牌。这些缓存将用于计算自注意力以生成下一个令牌。只有键和值令牌被缓存,而查询令牌不被缓存,因此称为KV缓存。
我们为什么需要KV缓存? 让我们看看下面的图来澄清我们的好奇心。
[作者图片]: KV缓存实现
2d. 组查询注意力:
组查询注意力与之前模型中使用的多头注意力相同,例如Llama 1,唯一的区别在于查询使用单独的头,而键/值使用单独的头。通常,分配给查询的头的数量是键、值头数量的n倍。让我们看看图表,以进一步加深我们的理解。
[作者图像]: 组查询注意力和多头注意力
在给定的图中,多头注意力在所有查询、键和值之间具有相等数量的头,n_heads = 8\。
组查询注意力模块有8个查询头(n_heads)和4个键值头(n_kv_heads),是查询头数量的2倍。
由于多头注意力已经如此优秀,我们为什么还需要组查询注意力? 要回答这个问题,我们需要暂时回到KV缓存。KV缓存大大减少了计算资源。然而,当KV缓存存储越来越多的以前的标记时,内存资源将显著增加。从模型性能和经济的角度来看,这不是一个好事。因此,引入了组查询注意力。 通过减少K和V的头数,它减少了要存储的参数数量,从而使用更少的内存。各种测试结果证明,这种方法下模型的准确性保持在相同的范围内。
让我们用代码来实现这个:
## 注意力块 [步骤2c: KV缓存;步骤2d: 分组查询注意力]
## 如前所述,命名约定遵循原始meta的LLama3 GitHub
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
# 嵌入维度
self.dim = args.dim
# 分配给查询的头数
self.n_heads = args.n_heads
# 分配给键和值的头数。如果为“None”,则数量将与查询相同。
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
# 相对于模型维度的每个头的维度
self.head_dim = args.dim // args.n_heads
# 重复次数以使键和值的头数与查询头数匹配
self.n_rep = args.n_heads // args.n_kv_heads
# 初始化键、查询、值和输出的权重。注意,对于q和kv的权重,out_feature值是基于头数量的
self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False, device=device)
self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False, device=device)
self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False, device=device)
self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False, device=device)
# 初始化缓存以存储键、值的开始。(KV缓存实现)
self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim), device=args.device)
self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim), device=args.device)
def forward(self, x: torch.Tensor, start_pos, inference):
# 输入嵌入的形状: [bsz,seq_len,dim]
bsz, seq_len, _ = x.shape
# 在“训练”期间使用mask,因KV缓存的使用在“推理”中不需要。
mask = None
xq = self.wq(x) #x[bsz,seq_len,dim]*wq[dim,n_heads * head_dim] -> q[bsz,seq_len,n_heads * head_dim]
xk = self.wk(x) #x[bsz,seq_len,dim]*wq[dim,n_kv_heads * head_dim] -> k[bsz,seq_len,n_kv_heads * head_dim]
xv = self.wv(x) #x[bsz,seq_len,dim]*wq[dim,n_kv_heads * head_dim] -> v[bsz,seq_len,n_kv_heads * head_dim]
# 根据头数重塑查询、键和值。 (分组查询注意力实现)
xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim) #xq[bsz,seq_len,n_heads, head_dim]
xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim) #xk[bsz,seq_len,n_kv_heads, head_dim]
xv = xv.view(bsz, seq_len, self.n_kv_heads, self.head_dim) #xv[bsz,seq_len,n_kv_heads, head_dim]
# 模型 - 推理模式: kv-cache仅在推理模式下启用。
if inference:
# 计算序列中每个位置的旋转矩阵
freqs_cis = precompute_freqs_cis(dim=self.head_dim, seq_len=self.args.max_seq_len * 2)
# 在推理时,我们只应取旋转矩阵的范围,从当前token的位置开始。
freqs_cis = freqs_cis[start_pos : start_pos + seq_len]
# 将RoPE应用于查询和键嵌入
xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
self.cache_k = self.cache_k.to(xq)
self.cache_v = self.cache_v.to(xq)
# 将键和值的token嵌入存储到各自的缓存中 [KV缓存实现]
self.cache_k[:bsz, start_pos:start_pos + seq_len] = xk
self.cache_v[:bsz, start_pos:start_pos + seq_len] = xv
# 将当前token位置之前所有token的嵌入分配给键和值变量,以进行注意力计算
keys = self.cache_k[:bsz, :start_pos + seq_len]
values = self.cache_v[:bsz, :start_pos + seq_len]
# 此时,键和值的形状与必须相同的查询嵌入不一样,以便计算注意力分数
# 使用repeat_kv函数使键、值的形状与查询的形状相同
keys = repeat_kv(keys, self.n_rep) #keys[bsz,seq_len,n_heads,head_dim]
values = repeat_kv(values, self.n_rep) #values[bsz,seq_len,n_heads,head_dim]
# 模式 - 训练模式: 未实现KV缓存
else:
# 计算旋转矩阵并将RoPE应用于查询和键以进行训练。
freqs_cis = precompute_freqs_cis(dim=self.head_dim, seq_len=self.args.max_seq_len)
#xq[bsz,seq_len,n_heads, head_dim], xk[bsz,seq_len,n_heads, head_dim]
xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
# 使用repeat_kv函数使键、值的形状与查询的形状相同
#keys[bsz,seq_len,n_heads,head_dim], #values[bsz,seq_len,n_heads,head_dim]
keys = repeat_kv(xk, self.n_rep)
values = repeat_kv(xv, self.n_rep)
# 对于训练模式,我们稍后将计算mask并应用于注意力分数
mask = torch.full((seq_len, seq_len),float("-inf"),device=self.args.device)
mask = torch.triu(mask, diagonal=1).to(self.args.device)
# 为了计算注意力,我们需要执行转置操作以重塑所有查询、键和值,将头置于dim 1,将序列置于dim 2
xq = xq.transpose(1,2) #xq[bsz,n_heads,seq_len,head_dim]
keys = keys.transpose(1,2) #keys[bsz,n_heads,seq_len,head_dim]
values = values.transpose(1,2) #values[bsz,n_heads,seq_len,head_dim]
# 计算注意力分数
scores = torch.matmul(xq, keys.transpose(2,3)).to(self.args.device)/math.sqrt(self.head_dim)
if mask is not None:
scores = scores + mask
# 对注意力分数应用softmax
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
# 将注意力分数与值做矩阵乘法
output = torch.matmul(scores, values).to(self.args.device)
# 我们得到每个头的上下文嵌入
# 所有头需要重新形状并结合以给出单一的上下文注意力输出
# 形状变化: output[bsz,n_heads,seq_len,head_dim] -> output[bsz,seq_len, n_heads,head_dim] -> output[bsz,seq_len, n_heads * head_dim]
output = output.transpose(1,2).contiguous().view(bsz, seq_len, -1)
# 形状: output [bsz,seq_len,dim]
return self.wo(output)
# 如果键/值头的数量少于查询头的数量,则此函数将所需的重复次数扩展键/值嵌入
def repeat_kv(x:torch.Tensor, n_rep: int)-> torch.Tensor:
bsz, seq_len, n_kv_heads, head_dim = x.shape
if n_rep == 1:
return x
return (
x[:,:,:,None,:]
.expand(bsz,seq_len,n_kv_heads,n_rep, head_dim)
.reshape(bsz,seq_len,n_kv_heads * n_rep, head_dim)
)
### 测试: Repeat_kv函数 ###
# 注意: xk, x_norm已在RoPE、RMSNorm测试期间计算,并在此处用作测试。
# 你需要去掉下面的三重引号以进行测试
"""
n_rep = ModelArgs.n_heads // ModelArgs.n_kv_heads
keys = repeat_kv(xk, n_rep)
print(f"xk.shape: {xk.shape}")
print(f"keys.shape: {keys.shape}")
## 测试: 注意力函数
# 你需要去掉下面的三重引号以进行测试
attention = Attention(ModelArgs)
x_out = attention(x_norm,start_pos=0, inference=False)
print(f"x_out.shape: {x_out.shape}")
"""
### 测试结果: ###
"""
xk.shape: torch.Size([10, 256, 4, 64])
keys.shape: torch.Size([10, 256, 8, 64])
x_out.shape: torch.Size([10, 256, 512])
"""
2e. 前馈网络 (SwiGLU 激活):
解码器块中的前馈网络做什么? 如上面的架构图所示,注意力输出首先在 RMSNorm 中进行归一化,然后输入到前馈网络中。在前馈网络内部,注意力输出的嵌入将在其隐藏层中扩展到更高的维度,并学习更复杂的令牌特征。
为什么使用SwiGLU而不是ReLU? 让我们看看图表以获取答案。
[Image by writer]: 具有SwiGLU函数的前馈网络
如上图所示,SwiGLU函数在正轴上的表现几乎和ReLU相同。然而,在负轴上,SwiGLU输出一些负值,这在学习小值而不是ReLU情况下的平坦0时可能是有用的。总体而言,根据作者的说法,使用SwiGLU的性能优于使用ReLU,因此被选中。
让我们深入了解FeedForward代码:
## Step2e: 前馈网络 (SwiGLU 激活)
class FeedForward(nn.Module):
def __init__(self, dim:int, hidden_dim:int, multiple_of:int, ffn_dim_multiplier: Optional[float]):
super().__init__()
# 模型嵌入维度
self.dim = dim
# 我们必须使用 Meta 共享的隐藏维度计算,这是该模型的理想选择
# 隐藏维度计算为256的倍数。
hidden_dim = int(2 * hidden_dim/3)
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
# 定义隐藏层权重
self.w1 = nn.Linear(self.dim, hidden_dim, bias=False, device=device)
self.w2 = nn.Linear(hidden_dim, self.dim, bias=False, device=device)
self.w3 = nn.Linear(self.dim, hidden_dim, bias=False, device=device)
def forward(self, x):
# 形状: [bsz, seq_len, dim]
return self.w2(F.silu(self.w1(x)) * self.w3(x))
### 测试: FeedForward 模块 ###
# 注意: x_out 已在 Attention 测试中计算,并在此处用于测试。
# 您需要去掉下面的三重引号以执行测试
"""
feed_forward = FeedForward(ModelArgs.dim, 4 * ModelArgs.dim, ModelArgs.multiple_of, ModelArgs.ffn_dim_multiplier)
x_out = rms_norm(x_out)
x_out = feed_forward(x_out)
print(f"前馈输出: x_out.shape: {x_out.shape}")
"""
### 测试结果: ###
"""
前馈输出: x_out.shape: torch.Size([10, 256, 512])
"""
2f. 解码器块:
如上图所示(最初的图表)。解码器块由多个子组件组成,这些组件我们在之前的部分(2a - 2f)中学习和编码过。下面是解码器块内部进行的逐点操作。
- 输入块的嵌入被送入 Attention-RMSNorm 块。这将进一步送入 Group Query Attention 块。
- 输入块中的相同嵌入将被添加到注意力输出中。
- 之后,注意力输出被输入到前馈-RMSNorm,并进一步输入到前馈网络块中。
- 然后将前馈网络的输出与注意力输出再次相加。
- 生成的输出称为 解码器输出。 该解码器输出随后作为输入传入另一个解码器块。这个操作将对接下来的31个解码器块重复进行。第32个解码器块的最终解码器输出将传递到输出块。
让我们在下面的代码中看看这个动作:
## Step2f: 解码器块。类名被赋值为 TransformerBlock,以匹配 Meta llama 3 代码库的名称。
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
# 初始化 RMSNorm 用于注意力
self.attention_norm = RMSNorm(dim=args.dim, eps = args.norm_eps)
# 初始化注意力类
self.attention = Attention(args)
# 初始化 RMSNorm 用于前馈类
self.ff_norm = RMSNorm(dim=args.dim, eps = args.norm_eps)
# 初始化前馈类
self.feedforward = FeedForward(args.dim, 4 * args.dim, args.multiple_of, args.ffn_dim_multiplier)
def forward(self, x, start_pos, inference):
# start_pos = 推理模式的标记位置,inference = True 表示推理,False 表示训练模式
# i) 将输入嵌入传递给 attention_norm,然后传递给注意力块。
# ii) 注意力的输出然后加到嵌入上(在归一化之前)
h = x + self.attention(self.attention_norm(x), start_pos, inference)
# i) 将注意力输出传递给 ff_norm,然后传递给前馈网络。
# ii) 前馈网络的输出然后加到注意力输出上(在 ff_norm 之前)
out = h + self.feedforward(self.ff_norm(h))
# 形状: [bsz,seq_len,dim]
return out
### 测试: TransformerBlock ###
# 您需要去掉下面的三重引号以进行测试
"""
x = torch.randn((ModelArgs.max_batch_size, ModelArgs.max_seq_len, ModelArgs.dim), device=device)
transformer_block = TransformerBlock(ModelArgs)
transformer_block_out = transformer_block(x,start_pos=0, inference=False)
print(f"transformer_block_out.shape: {transformer_block_out.shape}")
"""
### 测试结果: ###
"""
transformer_block_out.shape: torch.Size([10, 64, 128])
"""
第3步:输出块
最终解码器块的解码器输出将输入到输出块。它首先被输入到RMSNorm。然后,它将输入到生成logits的线性层。接下来,会发生以下两种操作之一。
让我们看一下输出块流程图以获得更多清晰度。
[作者提供的图片]: LLama 3 的训练和推断模式输出流程图
最后,让我们将3个模块(输入模块、解码器模块和输出模块)结合在一起。这就是我们的最终Llama 3模型。
让我们编写最终的Llama 3模型:
## Step3: 输出块
# 这是 Llama 3 模型。同样,类名保持为 Transformer,以与 Meta Llama 3 模型匹配。
class Transformer(nn.Module):
def __init__(self, params: ModelArgs):
super().__init__()
# 设置 params 变量中的所有 ModelArgs
self.params = params
# 从输入块初始化嵌入类
self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
# 初始化解码块并将其存储在 ModuleList 中。
# 这是因为我们在 Llama 3 模型中有 4 个解码块。(官方的 Llama 3 有 32 个块)
self.layers = nn.ModuleList()
for layer_id in range(params.n_layers):
self.layers.append(TransformerBlock(args=params))
# 为输出块初始化 RMSNorm
self.norm = RMSNorm(params.dim, eps = params.norm_eps)
# 在输出块初始化线性层。
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
def forward(self, x, start_pos=0, targets=None):
# start_pos = 推断模式下的 token 位置,推断 = True 表示推断,False 表示训练模式
# x 是使用分词器从文本或提示生成的 token_ids 的批次。
# x[bsz, seq_len] -> h[bsz, seq_len, dim]
h = self.tok_embeddings(x)
# 如果目标为 None,则激活推断模式, 如果激活训练模式则设置为 "False"。
if targets is None:
inference = True
else:
inference = False
# 嵌入 (h) 将通过所有解码块。
for layer in self.layers:
h = layer(h, start_pos, inference)
# 最后一个解码块的输出将输入到 RMSNorm
h = self.norm(h)
# 正常化后,嵌入 h 将输入到线性层中。
# 线性层的主要任务是生成将嵌入与词汇大小映射的 logits。
# h[bsz, seq_len, dim] -> logits[bsz, seq_len, vocab_size]
logits = self.output(h).float()
loss = None
# 如果没有可用的目标,则激活推断模式
if targets is None:
loss = None
# 如果目标可用,则激活训练模式。损失将被计算用于进一步的模型训练。
else:
loss = F.cross_entropy(logits.view(-1, self.params.vocab_size), targets.view(-1))
return logits, loss
### 测试: Transformer (Llama模型) ###
# 你需要去掉下面的三重引号以进行测试
"""
model = Transformer(ModelArgs).to(ModelArgs.device)
print(model)
"""
[Image by Write]: LLama 3 层次架构
我们刚刚构建的Llama 3模型看起来完美。我们现在准备好开始我们的训练过程。
第 4 步:训练我们的 Llama 3 模型:
输出块流程图(步骤3)中提供了训练流程。如果您想在开始训练之前更清楚地了解,请再次参考该流程。让我们开始编写训练代码。我还将在代码块中提供必要的解释。
## 第4步:训练Llama 3模型:
# 使用我们在输入块部分构建的tokenizer的encode函数对整个tiny_shakespeare数据的token_ids列表进行编码,创建一个数据集
dataset = torch.tensor(encode(data), dtype=torch.int).to(ModelArgs.device)
print(f"dataset-shape: {dataset.shape}")
# 定义函数从给定数据集中生成批次
def get_dataset_batch(data, split, args:ModelArgs):
seq_len = args.max_seq_len
batch_size = args.max_batch_size
device = args.device
train = data[:int(0.8 * len(data))]
val = data[int(0.8 * len(data)): int(0.9 * len(data))]
test = data[int(0.9 * len(data)):]
batch_data = train
if split == "val":
batch_data = val
if split == "test":
batch_data = test
# 从数据集中随机选择起始点,以提供随机样本用于训练、验证和测试。
ix = torch.randint(0, len(batch_data) - seq_len - 3, (batch_size,)).to(device)
x = torch.stack([torch.cat([token_bos, batch_data[i:i+seq_len-1]]) for i in ix]).long().to(device)
y = torch.stack([torch.cat([batch_data[i+1:i+seq_len], token_eos]) for i in ix]).long().to(device)
return x,y
### 测试:get_dataset函数 ###
"""
xs, ys = get_dataset_batch(dataset, split="train", args=ModelArgs)
print([(decode(xs[i].tolist()), decode(ys[i].tolist())) for i in range(len(xs))])
"""
# 定义一个评估损失函数,以计算和存储训练和验证损失以供记录和绘图
@torch.no_grad()
def evaluate_loss(model, args:ModelArgs):
out = {}
model.eval()
for split in ["train", "val"]:
losses = []
for _ in range(10):
xb, yb = get_dataset_batch(dataset, split, args)
_, loss = model(x=xb, targets=yb)
losses.append(loss.item())
out[split] = np.mean(losses)
model.train()
return out
# 定义一个训练函数来执行模型训练
def train(model, optimizer, args:ModelArgs):
epochs = args.epochs
log_interval = args.log_interval
device = args.device
losses = []
start_time = time.time()
for epoch in range(epochs):
optimizer.zero_grad()
xs, ys = get_dataset_batch(dataset, 'train', args)
xs = xs.to(device)
ys = ys.to(device)
logits, loss = model(x=xs, targets=ys)
loss.backward()
optimizer.step()
if epoch % log_interval == 0:
batch_time = time.time() - start_time
x = evaluate_loss(model, args)
losses += [x]
print(f"Epoch {epoch} | val loss {x['val']:.3f} | Time {batch_time:.3f}")
start_time = time.time()
# 打印最终验证损失
print("验证损失:", losses[-1]['val'])
# 在图中显示各个间隔的损失
return pd.DataFrame(losses).plot()
现在,我们已经定义了训练函数。让我们通过以下代码块开始训练,并在训练完成后观察图中的训练结果。
## 开始训练我们的 Llama 3 模型
model = Transformer(ModelArgs).to(ModelArgs.device)
optimizer = torch.optim.Adam(model.parameters())
train(model, optimizer, ModelArgs)
[image by writer]: 训练与验证损失图表
以上图像展示了训练和验证损失图。训练已进行2500个周期。使用Google Colab的默认GPU和RAM设置,训练过程大约花费了10分钟,非常快。最后一个周期的验证损失为2.19,考虑到我们使用的训练数据量和周期数量,这被认为是可以接受的。为了显著降低损失,我们将需要增加训练数据的规模、增加周期数量以及更高的GPU或处理能力。
现在我们已经完成了训练。让我们进入最后一步 — 推理,看看模型在给定新输入提示时生成输出文本的效果如何。
第5步:推理Llama 3模型:
推理流程在输出块流程图中提供(步骤 3)。让我们开始编写推理代码。
## 第5步:推理Llama 3模型:
# 此函数根据提供的提示生成基于我们构建和训练的LLama 3模型的文本序列。
def generate(model, prompts: str, params: ModelArgs, max_gen_len: int=500, temperature: float = 0.6, top_p: float = 0.9):
# prompt_tokens: 用户输入文本或提示的列表
# max_gen_len: 生成的文本序列的最大长度。
# temperature: 控制采样随机性的温度值。默认为0.6。
# top_p: 从logits中采样prob输出的top-p概率阈值。默认为0.9。
# prompt_tokens = [0]
bsz = 1 #对于推理,一般用户只输入一个提示,我们将其视为1个批次
prompt_tokens = token_bos.tolist() + encode(prompts)
assert len(prompt_tokens) <= params.max_seq_len, "提示标记长度应小于最大序列长度"
total_len = min(len(prompt_tokens)+max_gen_len, params.max_seq_len)
# 此标记矩阵用于存储输入提示和模型生成的所有输出。
# 稍后我们将使用tokenizers解码功能解码这些标记,以文本格式查看结果
tokens = torch.full((bsz,total_len), fill_value=token_pad.item(), dtype=torch.long, device=params.device)
# 将提示标记填入标记矩阵
tokens[:,:len(prompt_tokens)] = torch.tensor(prompt_tokens, dtype=torch.long, device=params.device)
# 创建一个prompt_mask_token以便后续使用,以识别标记是提示标记还是填充标记
# 如果是提示标记则为True,如果是填充标记则为False
input_text_mask = tokens != token_pad.item()
# 现在我们可以开始使用提示标记列表中的一个标记进行推理,从第一个位置开始。
prev_pos = 0
for cur_pos in range(1, total_len):
with torch.no_grad():
logits, _ = model(x=tokens[:,prev_pos:cur_pos], start_pos=prev_pos)
if temperature > 0:
probs = torch.softmax(logits[:, -1]/temperature, dim=-1)
next_token = sample_top_p(probs, top_p)
else:
next_token = torch.argmax(logits[:, -1], dim=-1)
next_token = next_token.reshape(-1)
# 仅在它是填充标记时替换标记
next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
tokens[:, cur_pos] = next_token
prev_pos = cur_pos
if tokens[:,cur_pos]==token_pad.item() and next_token == token_eos.item():
break
output_tokens, output_texts = [], []
for i, toks in enumerate(tokens.tolist()):
# eos_idx = toks.index(token_eos.item())
if token_eos.item() in toks:
eos_idx = toks.index(token_eos.item())
toks = toks[:eos_idx]
output_tokens.append(toks)
output_texts.append(decode(toks))
return output_tokens, output_texts
# 在概率分布上执行top-p(核)采样。
# probs (torch.Tensor): 从logits派生的概率分布张量。
# p: top-p采样的概率阈值。
# 根据论文,Top-p采样选择总概率质量超过阈值p的最小标记集。
# 该分布基于所选标记重新归一化。
def sample_top_p(probs, p):
probs_sort, prob_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > p
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
next_token = torch.multinomial(probs_sort, num_samples=1)
next_token = torch.gather(prob_idx, -1, next_token)
# 返回从词汇中抽样的标记索引
return next_token
让我们对新的提示进行推理并检查生成的输出
## 对用户输入的提示进行推理
prompts = "考虑他做了什么服务"
output_tokens, output_texts = generate(model, prompts, ModelArgs)
output_texts = output_texts[0].replace("<|begin_of_text|>", "")
print(output_texts)
## 输出 ##
"""
考虑他做了什么服务 o eretrane
adetranytnn i eey i ade hs rcuh i eey,ad hsatsTns rpae,T
eon o i hseflns o i eee ee hs ote i ocal ersl,Bnnlnface
o i hmr a il nwye ademto nt i a ere
h i ees.
Frm oe o etrane o oregae,alh,t orede i oeral
"""
而且是的,我们可以看到我们的Llama 3模型能够在新的提示上进行推理并生成文本,尽管考虑到我们用于训练的训练数据和训练轮次,输出似乎并不理想。我相信通过更大的训练数据,我们会获得更好的准确性。
这就是了! 我们成功地从头开始构建了自己的 Llama 3 模型。我们还成功地训练了该模型,并且成功地进行了推理,在使用 Google Colab Notebook 提供的免费 GPU 和 RAM 的短时间内生成了新文本。如果你一路跟随到这里,我要亲自祝贺你所付出的巨大努力。
我的最终想法
Llama 3及其其他变体是当前LLM领域最受欢迎的开源LLM。我相信,从头开始构建Llama 3的能力提供了构建许多新兴激动人心的基于LLM的应用程序所需的所有基础。我真心相信知识应该是人人皆可自由获取的,请随意使用源代码并进行更新,以构建您的个人或专业项目。祝大家好运。
非常感谢您的阅读!
链接到 Google Colab 笔记本
https://github.com/tamangmilan/llama3/blob/main/build_llama3_from_scratch.ipynb
相关推荐
- 从IDEA开始,迈进GO语言之门(idea got)
-
前言笔者在学习GO语言编程的时候,GO语言在国内还没有像JAVA/Php/Python那样普及,绕了不少的弯路,要开始入门学习一门编程语言,最好就先从选择一个好的编程语言的开发环境开始,有了这个开发环...
- 基于SpringBoot+MyBatis的私人影院java网上购票jsp源代码Mysql
-
本项目为前几天收费帮学妹做的一个项目,JavaEEJSP项目,在工作环境中基本使用不到,但是很多学校把这个当作编程入门的项目来做,故分享出本项目供初学者参考。一、项目介绍基于SpringBoot...
- 基于springboot的个人服装管理系统java网上商城jsp源代码mysql
-
本项目为前几天收费帮学妹做的一个项目,JavaEEJSP项目,在工作环境中基本使用不到,但是很多学校把这个当作编程入门的项目来做,故分享出本项目供初学者参考。一、项目介绍基于springboot...
- 基于springboot的美食网站Java食品销售jsp源代码Mysql
-
本项目为前几天收费帮学妹做的一个项目,JavaEEJSP项目,在工作环境中基本使用不到,但是很多学校把这个当作编程入门的项目来做,故分享出本项目供初学者参考。一、项目介绍基于springboot...
- 贸易管理进销存springboot云管货管账分析java jsp源代码mysql
-
本项目为前几天收费帮学妹做的一个项目,JavaEEJSP项目,在工作环境中基本使用不到,但是很多学校把这个当作编程入门的项目来做,故分享出本项目供初学者参考。一、项目描述贸易管理进销存spring...
- SpringBoot+VUE员工信息管理系统Java人员管理jsp源代码Mysql
-
本项目为前几天收费帮学妹做的一个项目,JavaEEJSP项目,在工作环境中基本使用不到,但是很多学校把这个当作编程入门的项目来做,故分享出本项目供初学者参考。一、项目介绍SpringBoot+V...
- 目前见过最牛的一个SpringBoot商城项目(附源码)还有人没用过吗
-
帮粉丝找了一个基于SpringBoot的天猫商城项目,快速部署运行,所用技术:MySQL,Druid,Log4j2,Maven,Echarts,Bootstrap...免费给大家分享出来前台演示...
- SpringBoot+Mysql实现的手机商城附带源码演示导入视频
-
今天为大家带来的是基于SpringBoot+JPA+Thymeleaf框架的手机商城管理系统,商城系统分为前台和后台、前台用的是Bootstrap框架后台用的是SpringBoot+JPA都是现在主...
- 全网首发!马士兵内部共享—1658页《Java面试突击核心讲》
-
又是一年一度的“金九银十”秋招大热门,为助力广大程序员朋友“面试造火箭”,小编今天给大家分享的便是这份马士兵内部的面试神技——1658页《Java面试突击核心讲》!...
- SpringBoot数据库操作的应用(springboot与数据库交互)
-
1.JDBC+HikariDataSource...
- SpringBoot 整合 Flink 实时同步 MySQL
-
1、需求在Flink发布SpringBoot打包的jar包能够实时同步MySQL表,做到原表进行新增、修改、删除的时候目标表都能对应同步。...
- SpringBoot + Mybatis + Shiro + mysql + redis智能平台源码分享
-
后端技术栈基于SpringBoot+Mybatis+Shiro+mysql+redis构建的智慧云智能教育平台基于数据驱动视图的理念封装element-ui,即使没有vue的使...
- Springboot+Mysql舞蹈课程在线预约系统源码附带视频运行教程
-
今天发布的是由【猿来入此】的优秀学员独立做的一个基于springboot脚手架的Springboot+Mysql舞蹈课程在线预约系统,系统项目源代码在【猿来入此】获取!https://www.yuan...
- SpringBoot+Mysql在线众筹系统源码+讲解视频+开发文档(参考论文
-
今天发布的是由【猿来入此】的优秀学员独立做的一个基于springboot脚手架的在线众筹管理系统,主要实现了普通用户在线参与众筹基本操作流程的全部功能,系统分普通用户、超级管理员等角色,除基础脚手架外...
- Docker一键部署 SpringBoot 应用的方法,贼快贼好用
-
这两天发现个Gradle插件,支持一键打包、推送Docker镜像。今天我们来讲讲这个插件,希望对大家有所帮助!GradleDockerPlugin简介...
你 发表评论:
欢迎- 一周热门
- 最近发表
-
- 从IDEA开始,迈进GO语言之门(idea got)
- 基于SpringBoot+MyBatis的私人影院java网上购票jsp源代码Mysql
- 基于springboot的个人服装管理系统java网上商城jsp源代码mysql
- 基于springboot的美食网站Java食品销售jsp源代码Mysql
- 贸易管理进销存springboot云管货管账分析java jsp源代码mysql
- SpringBoot+VUE员工信息管理系统Java人员管理jsp源代码Mysql
- 目前见过最牛的一个SpringBoot商城项目(附源码)还有人没用过吗
- SpringBoot+Mysql实现的手机商城附带源码演示导入视频
- 全网首发!马士兵内部共享—1658页《Java面试突击核心讲》
- SpringBoot数据库操作的应用(springboot与数据库交互)
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- node卸载 (33)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- exceptionininitializererror (33)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)