Transformer中的KVCache优化原理

前记:现在KVCache已经属于是必备的技术了,但是博主发现自己只是听过这个名词,但是并不了解该技术的原理和实现,遂学习记录了本博客

1.首先KVCache技术有什么用?

答:KVCache技术主要帮助模型在推理过程,避免重复的计算,从而减少计算量,加快模型推理速度,同时也会带来成本的降低,设想如果是用户使用一款由该模型作为基座的大模型产品,那么更快的推理速度在用户层面会带来更好的用户体验,同时对于产品运营成本也有所降低。

2.KVCache技术是如何避免掉重复计算的?

由下图可知,token在输入模型计算的过程中,存在重复计算的部分,可以发现通过前n个token计算结果获得第n+1个token预测,然后通过这n+1个token再输入到网络中来计算第n+2个token,细细一想,聪明的你是不是发现了后面n+1个token计算的时候,前n个token是不是重复计算了,也就是说 n+1个token输入到模型中,实际上前面n个token的计算,和之前只有n个token输入到网络中的计算结果是一样的,这样就带来了重复计算问题,那我们如何避免掉重复计算问题,我们将前n个token的计算结果给缓存下来,那我们就只需要计算新加入这个token,从而大大降低了计算量。

LLMs计算

可能你看了上面的解释,还是有点云里雾里,不要着急,我们用代码来解释,从代码的结果来直观的来看到运算结果。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch
from transformers import GPT2Tokenizer, GPT2Model
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2Model.from_pretrained('gpt2')

# text: "The quick brown fox jumps over the lazy"
tokens = [[464, 2068, 7586, 21831, 18045, 625, 262, 16931]]
input_n = torch.tensor(tokens)
output_n = model(input_ids=input_n, output_hidden_states=True)

# text: " dog"
tokens[0].append(3290)
input_n_plus_1 = torch.tensor(tokens)
output_n_plus_1 = model(input_ids=input_n_plus_1, output_hidden_states=True)

for i, (hidden_n, hidden_n_plus_1) in enumerate(zip(output_n.hidden_states, output_n_plus_1.hidden_states)):
print(f"layer {i}, max difference {(hidden_n - hidden_n_plus_1[:, :-1, :]).abs().max().item()}")
assert torch.allclose(hidden_n, hidden_n_plus_1[:, :-1, :], atol=1e-4)

运算结果:

1
2
3
4
5
6
7
8
9
10
11
12
13
layer 0, max difference 0.0
layer 1, max difference 5.7220458984375e-06
layer 2, max difference 5.7220458984375e-06
layer 3, max difference 7.62939453125e-06
layer 4, max difference 2.86102294921875e-05
layer 5, max difference 1.9073486328125e-05
layer 6, max difference 9.5367431640625e-06
layer 7, max difference 1.9073486328125e-05
layer 8, max difference 2.6702880859375e-05
layer 9, max difference 2.6702880859375e-05
layer 10, max difference 2.6702880859375e-05
layer 11, max difference 3.0517578125e-05
layer 12, max difference 3.0517578125e-05

因此,我们可以用空间来换时间,牺牲掉内存空间来换取更快的推理速度。

参考:

[1]: https://zhuanlan.zhihu.com/p/686183300 “KVCache”
[2]: https://github.com/owenliang/pytorch-transformer/tree/kvcache# “KVCache项目”