import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, num_heads, dropout=0.1):
        """
        Args:
            embed_size (int): Dimension of the input embeddings.
            num_heads (int): Number of attention heads.
            dropout (float): Dropout probability.
        """
        super(MultiHeadAttention, self).__init__()
        assert embed_size % num_heads == 0, "Embedding size must be divisible by num_heads"

        self.embed_size = embed_size
        self.num_heads = num_heads
        self.head_dim = embed_size // num_heads

        # Linear projections for values, keys, and queries, applied to the full
        # embedding before it is split into heads, plus the output projection.
        self.values = nn.Linear(embed_size, embed_size, bias=False)
        self.keys = nn.Linear(embed_size, embed_size, bias=False)
        self.queries = nn.Linear(embed_size, embed_size, bias=False)
        self.fc_out = nn.Linear(embed_size, embed_size)
        self.dropout = nn.Dropout(dropout)
    def forward(self, values, keys, query, mask=None):
        """
        Forward pass.

        Args:
            values (torch.Tensor): Value tensor of shape [batch_size, value_len, embed_size].
            keys (torch.Tensor): Key tensor of shape [batch_size, key_len, embed_size].
            query (torch.Tensor): Query tensor of shape [batch_size, query_len, embed_size].
            mask (torch.Tensor): Mask tensor of shape [batch_size, 1, 1, key_len];
                positions where mask == 1 are excluded from attention.

        Returns:
            torch.Tensor: Output tensor of shape [batch_size, query_len, embed_size].
        """
        batch_size = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
        # Project the inputs, then split the embedding dimension into
        # (num_heads, head_dim): [batch_size, seq_len, num_heads, head_dim].
        values = self.values(values).reshape(batch_size, value_len, self.num_heads, self.head_dim)
        keys = self.keys(keys).reshape(batch_size, key_len, self.num_heads, self.head_dim)
        queries = self.queries(query).reshape(batch_size, query_len, self.num_heads, self.head_dim)
        # Move the head dimension forward: [batch_size, num_heads, seq_len, head_dim].
        values = values.transpose(1, 2)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        # Scaled dot-product attention scores: [batch_size, num_heads, query_len, key_len].
        attn_scores = torch.matmul(queries, keys.transpose(-2, -1)) / torch.sqrt(
            torch.tensor(self.head_dim, dtype=torch.float32)
        )
        # Block masked positions (mask == 1) before the softmax.
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 1, float("-inf"))
        # Normalize the scores into attention weights and apply dropout.
        attention = F.softmax(attn_scores, dim=-1)
        x = torch.matmul(self.dropout(attention), values)
        # Merge the heads back into a single embedding dimension:
        # [batch_size, query_len, embed_size].
        x = x.transpose(1, 2).contiguous()
        x = x.reshape(batch_size, query_len, self.embed_size)
        return self.fc_out(x)
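

# A minimal usage sketch, not part of the original listing: it assumes a
# self-attention setting, so the same tensor is passed as values, keys, and
# query, and the mask marks padded key positions with 1 to exclude them.
# The shapes (batch_size=2, seq_len=5, embed_size=64, num_heads=8) are
# illustrative choices, not values from the source.
if __name__ == "__main__":
    batch_size, seq_len, embed_size, num_heads = 2, 5, 64, 8

    attention = MultiHeadAttention(embed_size, num_heads, dropout=0.1)
    x = torch.randn(batch_size, seq_len, embed_size)

    # Block the last two key positions of every sequence (1 = masked out).
    mask = torch.zeros(batch_size, 1, 1, seq_len)
    mask[:, :, :, -2:] = 1

    out = attention(x, x, x, mask=mask)
    print(out.shape)  # torch.Size([2, 5, 64])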