目录

一.引言

二.KV-Cache 图解

1.Attention 计算

2.Generate WithOut KV-Cache

3.Generate With KV-Cache

4.Cache Memory Usage

三.KV-Cache 实践

1.WithOut KV-Cache

2.With KV-Cache

3.Compare Efficiency

四.总结


一.引言

LLM 推理中 KV-Cache 是最常见的优化方式,其通过缓存过去的 Keys、Values 从而提高 generate 每一个新 token 的速度,效果明显,是典型的空间换时间的做法,下面通过图示和 GPT-2 实测,看下 KV-Cache 的原理与实践。

二.KV-Cache 图解

1.Attention 计算

- MatMul-1        Q、K 负责计算当前 Token 与 候选 Token 之间的相似度

- Scale        防止 MatMul 值过大,对 MatMul 的值进行 Sqrt(d) 的缩放

- Mask        Causal Mask 时前后 Token 存在逻辑关系,后面的 Token 权重为 0 或很小的数

- SoftMax        权重归一化

- MatMul-2        根据相似度加权平均获取当前 Attention 后的结果

上面的流程简化一下,可以看作是一次 '基于 QK 相似度对 V 的加权平均' 的操作:

2.Generate WithOut KV-Cache

KV Cache 用于推理过程,下面我们以生成 "遥遥领先" 为例示范:

- <s>

生成遥遥领先之前,需要先从起始符 <s> 开始,其遵循前面图中的 Attention 计算公式:

- <s>遥

当前字符为 "<s>遥" 由于 '遥' 在 '<s>' 的后面,所以对于 '<s>' 而言,'遥' 的向量 V 是不会对 '<s>' 的 Attention 结果有影响的,也就是说对于 '<s>' 而言,'遥' 的向量 softmax 后权重是一个极小的接近于 0 的数字,下面有计算过程:

计算的得到的 1x2 的矩阵 1 对应 Batch Size,2 对应 seq_len,还有一个隐含的向量维度 Dim:

由 Att1、Att2 我们可以看到:

- Att1 的生成需要 K1、V1

- Att2 的生成需要 K1、K2、V1、V2

- <s>遥遥

Att3 的生成需要 K1K2K3V1V2V3

- <s>遥遥领

Att4 的生成需要 K1K2K3K4V1V2V3、V4

- <s>遥遥领先 ...

数学归纳法的原理是给定 F(0),F(1),再假设有 F(n-1) 看能否推出 F(n),

后续的生成过程就不再赘述了,根据 Casual Mask 的性质,不难得出:

 AttN 的生成需要 K1、,,,、KNV1、...、VN

3.Generate With KV-Cache

- Output Probability

Generate 生成 Next Token 时基于 Attention 的最后一个结果,举个例子:

'<s>遥遥领' 已经生成,此时需要预测 Next Token,通过 Attention 计算得到 1 x 4 x Dim 的 Attention 矩阵,而预测最终 token 的概率计算只参考最后 Dim 维的向量,即上图红框标注的位置。

- Repeat Counting

基于 Output Probability 的计算流程,再观察最右侧的 Attention 计算结果,我们发现每次计算都有很多的冗余,其实我们只需要获取每一步最后一维的 Tensor 即可,但是我们每一步都在重复计算前面的部分,所以 KV-Cache 应运而生:

 AttN 的生成需要 K1、,,,、KNV1、...、VN

通过缓存每一步的 Keys 和 Values 即可实现高效的 Next Token 的推理。

- Generate Process

通过缓存每一步的 K、V 实现高效的推理,因为我们只需要计算最后一维的向量即可,付出的代价是显存的增加,其与我们生成的 Response Token Length 成线性正比关系。

4.Cache Memory Usage

继续看刚才的示例,此时我们有如下参数:

- batch_size 1 \

- seq_len 4 \

- dim emb_size \ 

由于 Decoder-Only 是多层堆叠的结构,所以还有一个潜在的参数:

- layer_num N \

由于 Multi-Head Attention 是按照 emb_size = head_num x head_dim,所以我们这里直接按照总的 emb_size 计算,不再拆分 head,按照 FP16 计算其缓存的通式:

memory_usage = 2 x bsz x seq_len x  dim x layer_num x byte(FP16)

以 1 条样本、生成 512 长度、4096 维向量、32 层堆叠为例:

memory = 2 x 1 x 512 x 4096 x 32 x 2 = 268435456 / 1024 / 1024 = 256.0 MB

所以显存比较极限的场景下,也需要注意 KV-Cache 的显存占用,虽然是随着 SeqLen 线性增长的,但是架不住维度和堆叠的 Decoder 多。 

Tips:

这里解释下计算时前后两个 2 怎么来的,第一个 2 是因为 KV cache 里 K/V 各缓存一次;第二个 2 是因为 FP16 16 位占用 2 个 byte。

三.KV-Cache 实践

接下来我们实践下 KV-Cache,由于是本地实验,所以采用比较小的 GPT-2 作为实验 LLM。

Generate 时主要通过 past_key_values 传递 KV-Cache:

1.WithOut KV-Cache

import time

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer


def common(in_tokens, model, tokenizer, is_log=False):
    # inference
    token_eos = torch.tensor([198])  # line break symbol
    out_token = None
    i = 0
    st = time.time()
    with torch.no_grad():
        while out_token != token_eos:
            logits, _ = model(in_tokens)
            out_token = torch.argmax(logits[-1, :], dim=0, keepdim=True)
            in_tokens = torch.cat((in_tokens, out_token), 0)
            text = tokenizer.decode(in_tokens)
            if is_log:
                print(f'step {i} input: {text}', flush=True)
            i += 1
    end = time.time()

    out_text = tokenizer.decode(in_tokens)
    print(f'Input: {in_text}')
    print(f'Output: {out_text}')
    print(f"Total Cost: {end - st} Mean: {(end - st) / i}")

token_id = 198 为 GPT-2 的 <eos>,我们手动停止生成,这里可以看到每一个 token 的预测过程:

- logits 通过 model 计算 logits 概率

- argmax 通过 argmax 获取概率最大的 token

- input_tokens = in_tokens + new_token 持续追加 seq_len 长度

- text 通过 tokenizer decode 即可获取 token 转变后的字符 

2.With KV-Cache

def cache(in_tokens, model, tokenizer, is_log=False):
    # inference
    token_eos = torch.tensor([198])  # line break symbol
    out_token = None
    kvcache = None
    out_text = in_text
    i = 0
    st = time.time()
    with torch.no_grad():
        while out_token != token_eos:
            logits, kvcache = model(in_tokens, past_key_values=kvcache)  # 增加了一个 past_key_values 的参数
            out_token = torch.argmax(logits[-1, :], dim=0, keepdim=True)
            in_tokens = out_token  # 输出 token 直接作为下一轮的输入,不再拼接
            text = tokenizer.decode(in_tokens)
            if is_log:
                print(f'step {i} input: {text}', flush=True)
            i += 1
            out_text += text
    end = time.time()

    print(f'Input: {in_text}')
    print(f'Output: {out_text}')
    print(f"Total Cost: {end - st} Mean: {(end - st) / i}")

Generate Process: 

3.Compare Efficiency

if __name__ == '__main__':
    local_path = "/LLM/model/gpt2"

    model = GPT2LMHeadModel.from_pretrained(local_path, torchscript=True).eval()

    # tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained(local_path)
    in_text = "Cristiano Ronaldo is a"
    in_tokens = torch.tensor(tokenizer.encode(in_text))

    common(in_tokens, model, tokenizer)
    cache(in_tokens, model, tokenizer)

比较 common 和 cache 的生成效果和时间:

- Common

- Cache

- Efficient

生成的结果相同,Cache 只需 Common 耗时的 43% 左右即可完成相同的推理。

Tips:

由于本地测试且 seq 比较短,所以这里就不参考本机显存变化了,需要的话大家可以用前面公式计算一下,这里 GPT-2 的 dim = 768,layer_num = 12。

四.总结

Generate 流程、Attention 计算以及 KV-Cache 的流程大致就这么多,下面总结下:

- 注意 Scale

Attention 计算时有一个 Scale 的操作,图中没有标注,注意不要忘记。

- Generate

生成是一个 token 一个 token 生成的,后一个 token 基于前面的所有 token。

- Q Cache

KV-Cache,有的同学肯定有疑问,为啥不把 Q 也 Cache 了。因为:

 AttN 的生成需要 K1、,,,、KNV1、...、VN

还需要 QN,因此缓存 Q1、Q2 ... 对于计算 AttN 没有意义,而且每一个 Q 都是需要基于前面序列来生成。

- Use Cache

{
  "architectures": [
    "MistralForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 1,

...

  "use_cache": true,
  "vocab_size": 32000
}

当前新出的 LLM 模型都在 config 内置了 use_cache 参数,上面是 Mistral config 中的部分参数,KV-Cache 都是 infer 时默认开启的。

!! 最后感谢下面大佬们的输出:

大模型推理加速:看图学KV Cache

大模型推理性能优化之KV Cache解读

Logo

开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!

更多推荐