LLM - Transformer Multi-Head Attention 维度变化与源码详解

news/2024/4/25 18:55:06

一.引言

前面我们基于 LLM 大模型源码介绍了 Causal Mask 以及 ROPE 旋转位置编码的实现,本文介绍源码中 Transformer 的实现流程,我们基于代码逐行分析维度变化与代码含义,希望能够清晰的了解 LLM 中 Transformer 运行的流程。

二.Transformer 分层维度

上面这个 Transformer 的基础结构我们在之前已经提到过很多次,这里结合维度变化再啰嗦一次,更详细的介绍可以参考: LLM - Transformer && LLaMA2 结构分析与 LoRA 详解。

1.单条样本

- Embedding Layer

对于一个典型的 LLM 大模型,输入 Embedding 层的维度 d_model 通常指的是将输入的标记 token 通过一个 embedding 层映射转换为连续向量的维度。例如,在 BERT-base 模型中,d_model 是 768,而在当下大模型中 d_model 为 8192。

- Transformer Layer

Transformer 层的输出维度通常和输入 Embedding 层的维度一致,即 d_model。如果我们持续使用 BERT-base 的例子,那么每个 Transformer 层 [ BERT中称为encoder层,LLM 中多为 decoder 层 ] 的输出也将是维度为 768 / 8192 的向量。

- lm_head Layer 

最后的 lm_head(语言模型头)的维度通常等于词汇表的大小 vocab_size,因为 lm_head 的作用是将 Transformer 层的输出转换成每个词汇的概率分布。举例来说,如果模型处理的语言的词汇表大小为 30000 个单词,那么 lm_head 的输出维度就是 30000。

- hidden_states

hidden_states 是 Transformer 模型处理过程中的一个术语,常见于模型的中间输出和内部分析。其记录了隐层的激活值,对于每个输入标记 token,Transformer 的每个层都会有一个输出向量,它表示的是在该层的特定深度上输入的表示。对于一个 N 层堆叠的 Transformer 模型,对于一个给定的输入序列,模型将会有 N 个这样的隐藏状态集。其中每个隐藏状态也会包含注意力分布,这是 Transformer 的自注意力机制的一个关键组成部分,它允许模型在处理输入时衡量不同部分之间的相互依赖性。

Tips:

假设我们有一个 BERT-base 模型,它使用 12 层 Transformer,每层的输出维度为 768,若输入一个有 5 个 tokens 的序列,每个 token 会首先被转换成一个 768 维的 embedding 向量。因此,hidden_states 在模型刚开始时会是一个形状为 (5, 768) 的张量。经过 12 层 Transformer 层处理后最后输出的 hidden_states 将会是一个形状为 (12, 5, 768) 的 3 维张量,其中包含了序列中每个token 在各个层上的表征。

2.批次样本

上面给出了单条样本的转换流程,下面我们分析下 batch_size 情况下维度的变换流程。假设我们有一个 BERT-base 模型:

词汇表大小 vocab_size = 30000

嵌入层维度 d_model = 768

堆叠层数量 N = 12

最大序列长度 max_seq_length = 128

批次大小 batch_size = 32

以下是数据通过模型时维度的具体变化过程:

- Input Layer

输入层维度为 (batch_size, max_seq_length) 即 (32, 128),每一个 128 的张量表示批次中每个序列的 token_id,即 text 通过 tokenizer 处理后的结果。 

- Embedding Layer

(bsz, max_seq_length) 的整数张量会被送入 Embedding 层,以 Bert 为例,其会被映射到 (bsz, max_seq_length, d_model) 的维度,即 (32, 128, 768)。这表明我们现在有 32 条样本,每个序列有 128 个 768 维的词嵌入向量。

- Transformer Layer

每个 Transformer 层接受一个 (bsz, max_seq_length, d_model) 的张量,经过 multi_head_attention 后输出一个相同形状的张量,这是因为 transformer 层通常会保持输入输出的维度相同,因此经过本层映射后,维度依然为 (bsz, max_seq_length, d_model) 即 (32, 128, 768)。

- lm_head Layer

lm_head 线性层将 Transformer 层的输出 (bsz, max_seq_length, d_model) 转换为 (bsz, max_seq_length, vocab_size) 的张量,即 (32, 128, 30000)。这一层一般是通过 Linear 实现的,对于复杂的 LLM,还会有 MLP 层,但最终 lm_head 的目的都是将 d_model 映射到 vocab_size,即生成一个与词汇表大小匹配的权重矩阵,代表每个 token 可能性的分布。

Tips:

如果考虑中间的 hidden_states,那么对于序列中的每个 token,在每个 Transformer 层中,我们都会得到一个 768 维的向量。因此,对于整个 batch 来说,每一层的 hidden_states 的形状为(batch_size, max_seq_length, d_model),即 (32, 128, 768)。如果我们保存所有层的hidden_states,那么我们就得到了一个形状为 (num_layers, batch_size, max_seq_length, d_model) 的 4 维张量,即 (12, 32, 128, 768),这里 num_layers 就是前面提到的 N,即 LLM 中 transformer 层堆叠的数量,这样,你就可以看到不同维度如何随着数据流通过模型而变化。这里需要注意的是真实情况下由于序列化长度可能不同,还会涉及到填充 padding 和掩码 masking 来确保批量处理是有效的,然而这并不影响上述维度变化的基本流程。

三.Transformer 维度变换

为了大家可以在本机 debug 快速测试,下面的示例我们以 Bert 及其 tokenizer 作为基模型构建 token_id 以及 Embedding,后续的 Multi-Head Attention 我们基于 Qwen 的逻辑进行了迁移,保持主体实现风格不变,更完整的代码可以参考 HF 上 modeling.py。

1.Input Layer 

输入层以及嵌入层我们通过 Bert 模型的 tokenizer 获取:

#!/usr/bin/python
# -*- coding: UTF-8 -*-import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizerif __name__ == '__main__':tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')pretrained_bert = BertModel.from_pretrained('bert-base-uncased')input_texts = ["This is a test sentence.", "Here is another test sentence."]input_ids = [tokenizer.encode(text, add_special_tokens=True, max_length=10, padding='max_length', truncation=True,return_tensors='pt') for text in input_texts]input_ids = torch.cat(input_ids, dim=0)  # Concatenate and add batch dimension

为了方便我们 input_texts 构造两条样本,所以 bsz = 2、max_length = 10、d_model = 768,input_ids 维度为 (10, ):

通过 concat 得到 (bsz, max_length) = (2, 10) 的初始维度:

tensor([[ 101, 2023, 2003, 1037, 3231, 6251, 1012,  102,    0,    0],[ 101, 2182, 2003, 2178, 3231, 6251, 1012,  102,    0,    0]])

2.Embedding Layer

    with torch.no_grad():embedded_output = pretrained_bert(input_ids)[0]  # Get the output of the BERT modelprint(embedded_output.size())  # Output shape should be (2, 10, embedding_dim)

这里通过 bert 的 Embedding 层获取 input_id 对应的 Embedding,由于 d_model = 768,所以前面 token_id 的 (bsz, max_length) 转换为 (bsz, max_length, d_model) 即 (2, 10, 768):

tensor([[[-3.7545e-02,  5.3234e-04, -1.3553e-02,  ..., -1.9545e-01,2.3569e-01,  4.7479e-01],[-7.1746e-01, -2.8763e-01,  1.4100e-01,  ..., -5.5593e-01,6.1830e-01,  3.9255e-01],[-1.9318e-01, -4.0202e-01,  3.2924e-01,  ..., -1.5206e-01,3.4014e-01,  1.0233e+00],...,[ 1.5273e-01,  1.1651e-01,  1.5754e-01,  ...,  6.9833e-02,-8.5732e-01, -4.3875e-02],[ 7.0679e-02, -2.3521e-01,  6.1713e-01,  ..., -7.3852e-02,2.5070e-01, -6.3240e-02],[-1.3249e-01, -3.6026e-01,  3.5025e-01,  ..., -5.5981e-02,1.0420e-01, -4.3954e-01]],[[-2.9592e-02, -1.4164e-01, -2.2295e-03,  ..., -1.3087e-01,2.9421e-01,  5.5132e-01],[-1.0146e+00, -6.8757e-01,  1.9959e-01,  ..., -4.2000e-01,1.7332e-01,  9.2754e-02],[-1.3425e-01, -8.1044e-01,  2.6674e-01,  ...,  4.6978e-02,-1.0026e-01,  4.5293e-01],...,[ 4.5527e-01,  2.2234e-02, -3.6816e-01,  ...,  4.3154e-01,-8.6396e-01, -2.8542e-01],[ 1.4188e-01, -2.4001e-01,  6.5681e-01,  ..., -5.7224e-02,3.1025e-01, -9.0286e-02],[-3.9205e-02, -3.2815e-01,  4.7910e-01,  ..., -4.7641e-02,2.9916e-02, -4.5328e-01]]])

3.Multi-Head Attention

        embed_dim = embedded_output.size(-1)num_heads = 4model = BITDDDAttention(embed_dim, num_heads)output = model(embedded_output)print(output.size())  # Output shape should be (2, 10, embed_dim)

本层我们从 LLM modeling.py 中将 Atention 的核心部分迁移到 BITDDDAtention Class 中:

class BITDDDAttention(nn.Module):def __init__(self, embed_dim, num_heads):super(BITDDDAttention, self).__init__()self.embed_dim = embed_dim  # embedding 维度self.num_heads = num_heads  # head 数量assert embed_dim % num_heads == 0self.head_dim = embed_dim // num_heads# 构建 Q/K/V 向量以及最后的全连接 MLPself.query = nn.Linear(embed_dim, embed_dim)self.key = nn.Linear(embed_dim, embed_dim)self.value = nn.Linear(embed_dim, embed_dim)self.fc_out = nn.Linear(embed_dim, embed_dim)def forward(self, x):batch_size, seq_len, _ = x.size()# Split the embedding into num_heads and reshape to (batch_size, num_heads, seq_len, head_dim)query = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)key = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)value = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)# Compute the attention scoresattention_scores = torch.matmul(query, key.permute(0, 1, 3, 2)) / self.head_dim ** 0.5attention_probs = torch.softmax(attention_scores, dim=-1)# Apply the attention weights to the valueattention_output = torch.matmul(attention_probs, value).permute(0, 2, 1, 3).contiguous().view(batch_size, seq_len, self.embed_dim)# Apply a linear layer to the outputx = self.fc_out(attention_output)return x

下面我们逐行看下 Mutil-Head Attention 的执行流程与维度变化:

- Size

batch_size, seq_len, _ = x.size()

这一步解析 Attention 层输入的 batch 样本的 bsz、seq_len,由于 init 方法中已经给出了 emd_dim,所以这里使用 '_' 忽略。

- Q/K/V 获取

# Split the embedding into num_heads and reshape to (batch_size, num_heads, seq_len, head_dim)
query = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
key = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
value = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

self.query、key、value 都是 nn.Linear(embed_dim, embed_dim) 的线性转换层,Q/K/V 的处理逻辑是相同的,这里通过 view 即 resize 方法将线性转换后的向量 (bsz, seq_len, embed_dim) 转换为 (bsz, seq_len, num_heads, head_dim),最后通过 permute 交换位置得到 (bsz, num_heads, seq_len, head_dim) 的输出向量,用于后续 multi-head 的计算。这里通过 assert 判断是否整除:

self.head_dim = embed_dim // num_heads

根据上面 init 给出的 heads 以及 embed_dim,可以得到最终维度为: (2, 4, 10, 192):

tensor([[[[ 0.0185, -0.1872,  0.1827,  ..., -0.7914, -0.0074, -0.6228],[-0.1390,  0.4675,  0.0325,  ...,  0.0187,  0.0912, -0.2692],[-0.1342,  0.2904, -0.2637,  ...,  0.1130, -0.0226, -0.3510],...,[-0.1151, -0.2627, -0.6453,  ...,  0.4885, -0.1982, -0.1538],[-0.1281,  0.2321, -0.0815,  ..., -0.1740,  0.4909, -0.1373],[-0.1364,  0.2844, -0.0728,  ..., -0.0620,  0.3605, -0.2292]],[[ 0.3641,  0.1707, -0.0567,  ...,  0.0267,  0.3272,  0.1560],[-0.1206,  0.6853,  0.0990,  ..., -0.0875,  0.2414,  0.5490],[-0.4080,  0.0679,  0.3174,  ...,  0.0970, -0.0127,  0.1664],...,[-0.2878,  0.2856,  0.0777,  ..., -0.0791,  0.0847,  0.0545],[ 0.2381, -0.1032,  0.2887,  ...,  0.2219,  0.2837,  0.0345],[ 0.1421, -0.0956,  0.1983,  ...,  0.1784,  0.1827,  0.0776]],[[-0.2031, -0.2496, -0.0072,  ..., -0.1553, -0.0441,  0.0200],[-0.2028, -0.4097,  0.1779,  ...,  0.0333, -0.4005, -0.3453],[ 0.0926, -0.1818,  0.0492,  ...,  0.3059, -0.6175, -0.2858],...,[ 0.3494, -0.4813,  0.7086,  ...,  0.6181,  0.1515, -0.1279],[-0.0542,  0.3148,  0.0172,  ...,  0.0037, -0.2878, -0.1582],[-0.1381,  0.2450,  0.0490,  ..., -0.0824, -0.2504, -0.2464]],[[ 0.6905, -0.1202,  0.6489,  ...,  0.6069,  0.2634, -0.0595],[ 0.3937, -0.2795,  0.7692,  ...,  0.1321, -0.0240, -0.1484],[ 0.2260, -0.4332,  0.4651,  ..., -0.1797, -0.1127, -0.3294],...,[ 0.0168, -0.2892,  0.4032,  ..., -0.4515,  0.3833, -0.7699],[ 0.1970, -0.3264,  0.4196,  ...,  0.3044, -0.0819, -0.2083],[ 0.2492, -0.3419,  0.5813,  ...,  0.1855, -0.2431, -0.1149]]],[[[ 0.0225, -0.2359,  0.0754,  ..., -0.7577,  0.0936, -0.6233],[ 0.0479,  0.5459, -0.3047,  ..., -0.3134,  0.0416,  0.0397],[ 0.1172,  0.2506, -0.5461,  ...,  0.1287, -0.0441, -0.2074],...,[-0.3586, -0.3827, -0.6436,  ...,  0.3915, -0.2485, -0.1576],[-0.1502,  0.1852, -0.1007,  ..., -0.1310,  0.5079, -0.1868],[-0.1622,  0.2055, -0.1428,  ..., -0.0887,  0.3516, -0.2383]],[[ 0.4338,  0.2326, -0.0661,  ...,  0.0309,  0.3088,  0.1711],[-0.4011,  0.9250,  0.2983,  ..., -0.4108,  0.4223,  0.6880],[-0.2721,  0.4383,  0.6376,  ..., -0.0888, -0.0647, -0.0073],...,[ 0.1742,  0.2020, -0.1020,  ..., -0.1444,  0.2459,  0.1079],[ 0.2608, -0.0978,  0.2557,  ...,  0.2132,  0.2125,  0.0010],[ 0.1041, -0.1335,  0.1523,  ...,  0.1797,  0.1323,  0.0036]],[[-0.1826, -0.2200, -0.0026,  ..., -0.1664, -0.0773,  0.0607],[-0.1257, -0.2642,  0.6933,  ...,  0.4202, -0.1153, -0.3960],[-0.1353, -0.4837,  0.3527,  ...,  0.3592, -0.5616, -0.3685],...,[ 0.6056, -0.3298,  0.7872,  ...,  0.3984,  0.4775,  0.2213],[-0.1211,  0.3394, -0.0247,  ...,  0.0251, -0.3108, -0.1656],[-0.1572,  0.3040,  0.0164,  ..., -0.1026, -0.2737, -0.2175]],[[ 0.7236, -0.1187,  0.6491,  ...,  0.6230,  0.2401, -0.0061],[-0.0402, -0.0318,  0.7717,  ..., -0.0389,  0.1465, -0.3047],[ 0.2734, -0.4473,  0.6278,  ..., -0.3827, -0.0412, -0.7133],...,[-0.0418,  0.0670,  0.1462,  ..., -0.6109,  0.4838, -0.4277],[ 0.2340, -0.3250,  0.4256,  ...,  0.3217, -0.0688, -0.1837],[ 0.1940, -0.2878,  0.5281,  ...,  0.2155, -0.1810, -0.0649]]]])

- Attention Score 计算

# Compute the attention scores
attention_scores = torch.matmul(query, key.permute(0, 1, 3, 2)) / self.head_dim ** 0.5
attention_probs = torch.softmax(attention_scores, dim=-1)

Attention Score 的计算依赖于 Q/K,这里把 key 的维度通过 permute 做了转换,由  (bsz, num_heads, seq_len, head_dim) 变换为  (bsz, num_heads, head_dim, seq_len),matmul 相乘后得到 attention_scores 的维度为 (bsz, num_heads, seq_len, seq_len) 即 (2, 4, 10, 10),除以 sqrt(head_dim) 是在应用 scale_dot 防止 matmul 的乘积过大,而最后 softmax(dim=-1) 则将 Attention Score 的最后一维的 10 个数字进行了归一化:

tensor([[[[0.1188, 0.1000, 0.0912, 0.0963, 0.0984, 0.0862, 0.1004, 0.0955,0.1040, 0.1093],[0.1077, 0.0949, 0.0963, 0.1035, 0.0961, 0.0940, 0.0946, 0.0890,0.1090, 0.1149],[0.0965, 0.1010, 0.0930, 0.0972, 0.1032, 0.1031, 0.0989, 0.0982,0.1062, 0.1026],[0.0932, 0.1033, 0.0970, 0.0977, 0.0961, 0.1050, 0.0990, 0.1082,0.1006, 0.0999],[0.0947, 0.1033, 0.0949, 0.0945, 0.0957, 0.1036, 0.0964, 0.0985,0.1083, 0.1102],[0.0941, 0.1026, 0.0939, 0.0942, 0.0953, 0.1008, 0.1001, 0.1089,0.1038, 0.1063],[0.1017, 0.1019, 0.1008, 0.0937, 0.1095, 0.1007, 0.0913, 0.0832,0.1092, 0.1079],[0.0926, 0.1121, 0.1009, 0.0991, 0.0955, 0.1017, 0.0964, 0.1030,0.1004, 0.0982],[0.1010, 0.1021, 0.0948, 0.0954, 0.0976, 0.1024, 0.0916, 0.1032,0.1049, 0.1071],[0.1046, 0.1077, 0.0932, 0.0948, 0.1006, 0.1002, 0.0934, 0.0983,0.1042, 0.1032]],......         [[0.1047, 0.0999, 0.1045, 0.1054, 0.0979, 0.1071, 0.0863, 0.0884,0.1012, 0.1046],[0.1019, 0.1113, 0.1010, 0.0990, 0.0981, 0.1060, 0.0872, 0.0915,0.1017, 0.1022],[0.0977, 0.0996, 0.0993, 0.1027, 0.0970, 0.0985, 0.0977, 0.0990,0.1069, 0.1018],[0.1042, 0.1125, 0.1049, 0.1022, 0.0981, 0.0950, 0.0864, 0.0876,0.1045, 0.1046],[0.0987, 0.1151, 0.1018, 0.0956, 0.0923, 0.0955, 0.0938, 0.0910,0.1073, 0.1087],[0.0985, 0.1143, 0.0936, 0.1029, 0.0954, 0.1028, 0.0857, 0.0901,0.1076, 0.1092],[0.0961, 0.0908, 0.1013, 0.1055, 0.0992, 0.1035, 0.0919, 0.0971,0.1090, 0.1056],[0.0874, 0.0935, 0.1012, 0.1057, 0.1044, 0.0968, 0.0936, 0.0950,0.1132, 0.1091],[0.1041, 0.1118, 0.0958, 0.0968, 0.0971, 0.1044, 0.0925, 0.0906,0.1028, 0.1041],[0.1028, 0.1123, 0.0971, 0.1005, 0.1013, 0.1020, 0.0896, 0.0894,0.1026, 0.1023]]]])

- Attention Output 

# Apply the attention weights to the value
attention_output = torch.matmul(attention_probs, value).permute(0, 2, 1, 3).contiguous().view(batch_size, seq_len, self.embed_dim)

Attention Probs 的维度为  ​​(2, 4, 10, 10) ,value 的维度为 (2, 4, 10, 192),相乘后得到 (2, 4, 10, 192) 即 (bsz, num_heads, seq_len, head_dim),通过 permute 转换为 (bsz, seq_len, num_heads, head_dim),再通过 view 将后两维 num_heads x head_dim 合并为 d_model,从而最终 attention_output 的维度为 (bsz, seq_len, d_model) 与原始 token_ids 通过 Embedding 层映射后的向量维度保持一致。

tensor([[[-0.0421,  0.0127, -0.3383,  ...,  0.1617, -0.2079, -0.3181],[-0.0401,  0.0172, -0.3488,  ...,  0.1581, -0.2098, -0.3260],[-0.0397,  0.0130, -0.3420,  ...,  0.1581, -0.2081, -0.3266],...,[-0.0428,  0.0137, -0.3372,  ...,  0.1589, -0.2145, -0.3305],[-0.0424,  0.0120, -0.3419,  ...,  0.1574, -0.2198, -0.3344],[-0.0426,  0.0140, -0.3463,  ...,  0.1562, -0.2191, -0.3338]],[[ 0.0502,  0.0926, -0.3300,  ...,  0.1376, -0.1264, -0.3689],[ 0.0369,  0.0917, -0.3339,  ...,  0.1439, -0.1089, -0.3571],[ 0.0419,  0.0915, -0.3328,  ...,  0.1480, -0.1168, -0.3654],...,[ 0.0438,  0.0946, -0.3290,  ...,  0.1435, -0.1302, -0.3702],[ 0.0417,  0.0898, -0.3281,  ...,  0.1358, -0.1374, -0.3759],[ 0.0428,  0.0906, -0.3306,  ...,  0.1345, -0.1330, -0.3752]]])

- Linear 浅层 MLP

# Apply a linear layer to the output
x = self.fc_out(attention_output)

fc_out 的维度是 nn.Linear(embed_dim, embed_dim),所有 attention_output 经过处理后 (bsz, seq_len, d_model) x (d_model, d_model) = (bsz, seq_len, d_model)。

tensor([[[ 1.0228e-01,  1.6250e-01, -1.4914e-01,  ..., -1.7511e-01,-2.1751e-03, -2.0877e-02],[ 9.9930e-02,  1.6427e-01, -1.4394e-01,  ..., -1.7894e-01,1.9605e-03, -2.4290e-02],[ 1.0188e-01,  1.6577e-01, -1.4313e-01,  ..., -1.7274e-01,5.3616e-03, -1.8874e-02],...,[ 1.0584e-01,  1.6541e-01, -1.4315e-01,  ..., -1.7077e-01,-4.8522e-04, -2.2207e-02],[ 1.0028e-01,  1.6638e-01, -1.3908e-01,  ..., -1.7138e-01,-4.0303e-05, -2.2604e-02],[ 1.0054e-01,  1.6448e-01, -1.4135e-01,  ..., -1.7086e-01,2.8514e-03, -1.9951e-02]],[[ 4.9912e-02,  1.3306e-01, -1.2705e-01,  ..., -1.2117e-01,3.5498e-02,  3.8191e-03],[ 4.8556e-02,  1.3361e-01, -1.2207e-01,  ..., -1.2270e-01,3.7410e-02, -3.3710e-03],[ 4.9592e-02,  1.3507e-01, -1.2446e-01,  ..., -1.2247e-01,4.3996e-02,  2.0591e-03],...,[ 5.2688e-02,  1.3105e-01, -1.2519e-01,  ..., -1.1373e-01,3.7038e-02,  2.5118e-03],[ 4.8786e-02,  1.3443e-01, -1.1793e-01,  ..., -1.1811e-01,3.4455e-02,  3.0611e-04],[ 4.7252e-02,  1.3401e-01, -1.1889e-01,  ..., -1.1601e-01,3.6708e-02,  2.7476e-03]]])

4.完整代码

#!/usr/bin/python
# -*- coding: UTF-8 -*-import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizerclass BITDDDAttention(nn.Module):def __init__(self, embed_dim, num_heads):super(BITDDDAttention, self).__init__()self.embed_dim = embed_dim  # embedding 维度self.num_heads = num_heads  # head 数量assert embed_dim % num_heads == 0self.head_dim = embed_dim // num_heads# 构建 Q/K/V 向量以及最后的全连接 MLPself.query = nn.Linear(embed_dim, embed_dim)self.key = nn.Linear(embed_dim, embed_dim)self.value = nn.Linear(embed_dim, embed_dim)self.fc_out = nn.Linear(embed_dim, embed_dim)def forward(self, x):batch_size, seq_len, _ = x.size()# Split the embedding into num_heads and reshape to (batch_size, num_heads, seq_len, head_dim)query = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)key = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)value = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)# Compute the attention scoresattention_scores = torch.matmul(query, key.permute(0, 1, 3, 2)) / self.head_dim ** 0.5attention_probs = torch.softmax(attention_scores, dim=-1)# Apply the attention weights to the valueattention_output = torch.matmul(attention_probs, value).permute(0, 2, 1, 3).contiguous().view(batch_size, seq_len, self.embed_dim)# Apply a linear layer to the outputx = self.fc_out(attention_output)return xif __name__ == '__main__':tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')pretrained_bert = BertModel.from_pretrained('bert-base-uncased')input_texts = ["This is a test sentence.", "Here is another test sentence."]input_ids = [tokenizer.encode(text, add_special_tokens=True, max_length=10, padding='max_length', truncation=True,return_tensors='pt') for text in input_texts]input_ids = torch.cat(input_ids, dim=0)  # Concatenate and add batch dimensionwith torch.no_grad():embedded_output = pretrained_bert(input_ids)[0]  # Get the output of the BERT modelprint(embedded_output.size())  # Output shape should be (2, 10, embedding_dim)embed_dim = embedded_output.size(-1)num_heads = 4model = BITDDDAttention(embed_dim, num_heads)output = model(embedded_output)print(output.size())  # Output shape should be (2, 10, embed_dim)

四.总结

上述代码可以在本地 CPU/GPU 环境跑起来,大家可以自己打断点熟悉整个过程维度的变化,计算的流程,Multi-Head Attention 分多个 head 计算不同 token 的注意力权重并加权求和,对于 Decoder-Only 的架构,其还会添加 Causal Mask 保证前面的文字看不到后面的文字。本文先介绍到 Transformer 的输出,后续我们介绍如何通过 Transformer 最后一层 lm_head 的输出计算 next_token 的概率并计算交叉熵 loss。


https://www.xjx100.cn/news/3280849.html

相关文章

Java Web(六)--XML

介绍 官网:XML 教程 为什么需要: 需求 1 : 两个程序间进行数据通信?需求 2 : 给一台服务器,做一个配置文件,当服务器程序启动时,去读取它应当监听的端口号、还有连接数据库的用户名和密码。spring 中的…

大型语言模型(LLM, Large Language Models)基模和 Chat 模型之间的区别

一、概述 最近看大模型相关的知识,有看到大模型都有基础模型(base)和对话模型(chat),不太清楚什么时候用到基础模型,什么时候用到对话模型,故有此文。 通过了解,最简单…

python从小白到大师-第一章Python应用(八)应用领域与常见包-自动化办公word

目录 一.python-docx 二.pypiwin32 一.python-docx Python-docx是一个用于创建、修改和读取Microsoft Word文件(.docx)的Python库。它提供了一组丰富的功能,使开发人员能够使用Python生成自定义的Word文档。 以下是python-docx库的一些主…

剪辑视频调色软件有哪些 剪辑视频软件哪个最好 剪辑视频怎么学 剪辑视频的方法和步骤 会声会影2024 会声会影视频制作教程

看了很多调色教程,背了一堆调色参数,可最终还是调不出理想的效果。别再怀疑自己了,不是你的剪辑技术不行,而是剪辑软件没选对。只要掌握了最基本的调色原理,一款适合自己的视频剪辑软件是很容易出片的。 有关剪辑视频…

绝地求生:图纸的加量不加价是否预示着蓝洞经营模式的转变

成长型武器目前作为PUBG中除了究极异色皮肤外的最高等级武器(传说级),也是PUBG核心利润来源,十分的珍贵。 一把成长型武器的保底价格为3000碎片,而每次通过G-coin抽取会赠送10个碎片,也就是需要抽取三百次&…

PHP小程序 获取二维码

//获取token public function getAccessToken($appId,$appSecret) {// 请求API获取 access_token$url "https://api.weixin.qq.com/cgi-bin/token?grant_typeclient_credential&appid{$this->appId}&secret{$this->appSecret}";$result $this->g…

Vue 使用 v-bind 动态绑定 CSS 样式

在 Vue3 中&#xff0c;可以通过 v-bind 动态绑定 CSS 样式。 语法格式&#xff1a; color: v-bind(数据); 基础使用&#xff1a; <template><h3 class"title">我是父组件</h3><button click"state !state">按钮</button&…

智能手机办公和PC电脑办公的区别,智能手机对PC电脑产生那些影响,

智能手机办公和PC电脑办公在工作方式、灵活性和便捷程度等方面有着显著的区别。本文将详细探讨这两种工作方式的优势和不同之处&#xff0c;并分析智能手机对PC电脑办公所产生的影响。 首先&#xff0c;智能手机办公相较于PC电脑办公具有更高的灵活性。由于智能手机轻便易携&am…