Attention is all you need.
德语->英语项目:
https://github.com/TCcjx/pytorch_transformer-remake
项目目录结构(文件说明):
.data(文件夹):数据集Multi30K
checkpoints(文件夹):训练权重文件保存处
config.py:全局配置文件,DEVICE以及输入token的最大长度
dataset.py:数据预处理文件,构建德语和英语词表,实现德语和英语的词元token和IDX的一一隐射,以及德语和英语句子预处理函数(输入德语和英语句子,返回分词后的词元列表信息,以及词元ID列表)
multihead_attn.py: 构建多头注意力机制模块,这里的实现同时也考虑了解码器中第二个多头注意力机制模块的代码复用,在编码器和解码器的多头注意力机制模块中都可以复用这个多头注意力机制的模块
encoder_block.py:编码器模块的构建
encoder.py:编码器的实现,同时自动处理PAD掩码矩阵,再传入encoder_block中,实现多个encoder_block的堆叠使用
decoder_block.py:解码器模块的构建
decoder.py: 解码器的实现,同时实现两个掩码矩阵,第一个掩码矩阵主要是PAD掩码以及三角掩码(遮蔽后续词元), 第二个掩码矩阵主要是为了遮蔽encoder隐藏表示中的PAD填充,防止解码器 Q到PAD
transformer:将编码器和解码器组装到一起,得到最后的词表大小的输出预测,用于预测下一个词元
train.py:训练文件
evaluation:预测效果测试评估文件
知识碎片-交叉熵损失函数
在 PyTorch 中,CrossEntropyLoss
是最常用的损失函数之一,适用于分类任务。下面详细介绍其传参方式和注意事项:
基本参数
1 | torch.nn.CrossEntropyLoss( |
输入要求
CrossEntropyLoss
要求两个输入:
预测值 (
**input**
):- 形状:
(N, C, d1, d2, ...)
,其中:N
:批次大小(Batch size)C
:类别数(Number of classes)d1, d2, ...
:可选的空间维度(如序列长度、高度、宽度等)
- 类型:未归一化的对数概率(logits),不需要经过 Softmax。
- 形状:
目标值 (
**target**
):- 形状:
(N, d1, d2, ...)
(不包含类别维度) - 类型:类别索引(整数,范围
0
到C-1
)。
- 形状:
常见场景示例
1. 标准分类(二维输入)
1 | import torch |
2. 序列分类(三维输入)
1 | # 输入: (batch_size, sequence_length, num_classes) |
3. 图像分割(四维输入)
1 | # 输入: (batch_size, num_classes, height, width) |
关键注意事项
无需手动 Softmax:
CrossEntropyLoss
内部包含LogSoftmax
和NLLLoss
,输入直接使用未归一化的 logits 即可。目标值格式:必须是类别索引(整数),而非 one-hot 编码。
权重平衡:通过
weight
参数可为不同类别设置权重,处理类别不平衡问题。1
2
3# 示例:为少数类设置更高权重
weights = torch.tensor([1.0, 2.0, 1.0, 1.0, 1.0]) # 第1类权重为2,其余为1
loss_fn = nn.CrossEntropyLoss(weight=weights)忽略特定类别:通过
ignore_index
参数可忽略某些目标值(如填充标记)。
常见错误
- 错误1:输入经过了 Softmax 处理。
- 解决:直接使用模型输出的 logits。
- 错误2:目标值包含超出类别范围的索引。
- 解决:确保目标值范围在
0
到C-1
之间。
- 解决:确保目标值范围在
- 错误3:维度不匹配(如类别维度位置错误)。
- 解决:使用
transpose
或permute
调整维度。
- 解决:使用