LLM - Generate With KV-Cache 图解与实践 By GPT-2
LLM Generate With KV-Cache 图解与实践 By GPT-2
目录
一.引言
LLM 推理中 KV-Cache 是最常见的优化方式,其通过缓存过去的 Keys、Values 从而提高 generate 每一个新 token 的速度,效果明显,是典型的空间换时间的做法,下面通过图示和 GPT-2 实测,看下 KV-Cache 的原理与实践。
二.KV-Cache 图解
1.Attention 计算
- MatMul-1 Q、K 负责计算当前 Token 与 候选 Token 之间的相似度
- Scale 防止 MatMul 值过大,对 MatMul 的值进行 Sqrt(d) 的缩放
- Mask Causal Mask 时前后 Token 存在逻辑关系,后面的 Token 权重为 0 或很小的数
- SoftMax 权重归一化
- MatMul-2 根据相似度加权平均获取当前 Attention 后的结果
上面的流程简化一下,可以看作是一次 '基于 QK 相似度对 V 的加权平均' 的操作:
2.Generate WithOut KV-Cache
KV Cache 用于推理过程,下面我们以生成 "遥遥领先" 为例示范:
- <s>
生成遥遥领先之前,需要先从起始符 <s> 开始,其遵循前面图中的 Attention 计算公式:
- <s>遥
当前字符为 "<s>遥" 由于 '遥' 在 '<s>' 的后面,所以对于 '<s>' 而言,'遥' 的向量 V 是不会对 '<s>' 的 Attention 结果有影响的,也就是说对于 '<s>' 而言,'遥' 的向量 softmax 后权重是一个极小的接近于 0 的数字,下面有计算过程:
计算的得到的 1x2 的矩阵 1 对应 Batch Size,2 对应 seq_len,还有一个隐含的向量维度 Dim:
由 Att1、Att2 我们可以看到:
- Att1 的生成需要 K1、V1
- Att2 的生成需要 K1、K2、V1、V2
- <s>遥遥
Att3 的生成需要 K1、K2、K3、V1、 V2、V3
- <s>遥遥领
Att4 的生成需要 K1、K2、K3、K4、V1、 V2、V3、V4
- <s>遥遥领先 ...
数学归纳法的原理是给定 F(0),F(1),再假设有 F(n-1) 看能否推出 F(n),
后续的生成过程就不再赘述了,根据 Casual Mask 的性质,不难得出:
AttN 的生成需要 K1、,,,、KN、V1、...、VN。
3.Generate With KV-Cache
- Output Probability
Generate 生成 Next Token 时基于 Attention 的最后一个结果,举个例子:
'<s>遥遥领' 已经生成,此时需要预测 Next Token,通过 Attention 计算得到 1 x 4 x Dim 的 Attention 矩阵,而预测最终 token 的概率计算只参考最后 Dim 维的向量,即上图红框标注的位置。
- Repeat Counting
基于 Output Probability 的计算流程,再观察最右侧的 Attention 计算结果,我们发现每次计算都有很多的冗余,其实我们只需要获取每一步最后一维的 Tensor 即可,但是我们每一步都在重复计算前面的部分,所以 KV-Cache 应运而生:
AttN 的生成需要 K1、,,,、KN、V1、...、VN
通过缓存每一步的 Keys 和 Values 即可实现高效的 Next Token 的推理。
- Generate Process
通过缓存每一步的 K、V 实现高效的推理,因为我们只需要计算最后一维的向量即可,付出的代价是显存的增加,其与我们生成的 Response Token Length 成线性正比关系。
4.Cache Memory Usage
继续看刚才的示例,此时我们有如下参数:
- batch_size 1 \
- seq_len 4 \
- dim emb_size \
由于 Decoder-Only 是多层堆叠的结构,所以还有一个潜在的参数:
- layer_num N \
由于 Multi-Head Attention 是按照 emb_size = head_num x head_dim,所以我们这里直接按照总的 emb_size 计算,不再拆分 head,按照 FP16 计算其缓存的通式:
memory_usage = 2 x bsz x seq_len x dim x layer_num x byte(FP16)
以 1 条样本、生成 512 长度、4096 维向量、32 层堆叠为例:
memory = 2 x 1 x 512 x 4096 x 32 x 2 = 268435456 / 1024 / 1024 = 256.0 MB
所以显存比较极限的场景下,也需要注意 KV-Cache 的显存占用,虽然是随着 SeqLen 线性增长的,但是架不住维度和堆叠的 Decoder 多。
Tips:
这里解释下计算时前后两个 2 怎么来的,第一个 2 是因为 KV cache 里 K/V 各缓存一次;第二个 2 是因为 FP16 16 位占用 2 个 byte。
三.KV-Cache 实践
接下来我们实践下 KV-Cache,由于是本地实验,所以采用比较小的 GPT-2 作为实验 LLM。
Generate 时主要通过 past_key_values 传递 KV-Cache:
1.WithOut KV-Cache
import time
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
def common(in_tokens, model, tokenizer, is_log=False):
# inference
token_eos = torch.tensor([198]) # line break symbol
out_token = None
i = 0
st = time.time()
with torch.no_grad():
while out_token != token_eos:
logits, _ = model(in_tokens)
out_token = torch.argmax(logits[-1, :], dim=0, keepdim=True)
in_tokens = torch.cat((in_tokens, out_token), 0)
text = tokenizer.decode(in_tokens)
if is_log:
print(f'step {i} input: {text}', flush=True)
i += 1
end = time.time()
out_text = tokenizer.decode(in_tokens)
print(f'Input: {in_text}')
print(f'Output: {out_text}')
print(f"Total Cost: {end - st} Mean: {(end - st) / i}")
token_id = 198 为 GPT-2 的 <eos>,我们手动停止生成,这里可以看到每一个 token 的预测过程:
- logits 通过 model 计算 logits 概率
- argmax 通过 argmax 获取概率最大的 token
- input_tokens = in_tokens + new_token 持续追加 seq_len 长度
- text 通过 tokenizer decode 即可获取 token 转变后的字符
2.With KV-Cache
def cache(in_tokens, model, tokenizer, is_log=False):
# inference
token_eos = torch.tensor([198]) # line break symbol
out_token = None
kvcache = None
out_text = in_text
i = 0
st = time.time()
with torch.no_grad():
while out_token != token_eos:
logits, kvcache = model(in_tokens, past_key_values=kvcache) # 增加了一个 past_key_values 的参数
out_token = torch.argmax(logits[-1, :], dim=0, keepdim=True)
in_tokens = out_token # 输出 token 直接作为下一轮的输入,不再拼接
text = tokenizer.decode(in_tokens)
if is_log:
print(f'step {i} input: {text}', flush=True)
i += 1
out_text += text
end = time.time()
print(f'Input: {in_text}')
print(f'Output: {out_text}')
print(f"Total Cost: {end - st} Mean: {(end - st) / i}")
Generate Process:
3.Compare Efficiency
if __name__ == '__main__':
local_path = "/LLM/model/gpt2"
model = GPT2LMHeadModel.from_pretrained(local_path, torchscript=True).eval()
# tokenizer
tokenizer = GPT2Tokenizer.from_pretrained(local_path)
in_text = "Cristiano Ronaldo is a"
in_tokens = torch.tensor(tokenizer.encode(in_text))
common(in_tokens, model, tokenizer)
cache(in_tokens, model, tokenizer)
比较 common 和 cache 的生成效果和时间:
- Common
- Cache
- Efficient
生成的结果相同,Cache 只需 Common 耗时的 43% 左右即可完成相同的推理。
Tips:
由于本地测试且 seq 比较短,所以这里就不参考本机显存变化了,需要的话大家可以用前面公式计算一下,这里 GPT-2 的 dim = 768,layer_num = 12。
四.总结
Generate 流程、Attention 计算以及 KV-Cache 的流程大致就这么多,下面总结下:
- 注意 Scale
Attention 计算时有一个 Scale 的操作,图中没有标注,注意不要忘记。
- Generate
生成是一个 token 一个 token 生成的,后一个 token 基于前面的所有 token。
- Q Cache
KV-Cache,有的同学肯定有疑问,为啥不把 Q 也 Cache 了。因为:
AttN 的生成需要 K1、,,,、KN、V1、...、VN
还需要 QN,因此缓存 Q1、Q2 ... 对于计算 AttN 没有意义,而且每一个 Q 都是需要基于前面序列来生成。
- Use Cache
{
"architectures": [
"MistralForCausalLM"
],
"attention_dropout": 0.0,
"bos_token_id": 1,
...
"use_cache": true,
"vocab_size": 32000
}
当前新出的 LLM 模型都在 config 内置了 use_cache 参数,上面是 Mistral config 中的部分参数,KV-Cache 都是 infer 时默认开启的。
!! 最后感谢下面大佬们的输出:
开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!
更多推荐
所有评论(0)