import torch

def create_pe_absolute_sincos_embedding_gov(n_pos_vec, dim):
    assert dim % 2 == 0, "dim must be even"
    position_embedding = torch.zeros(n_pos_vec.numel(), dim, dtype=torch.float)

    # frequencies: omega_i = 1 / 10000^(2i/dim) for i = 0 .. dim/2 - 1
    omega = torch.arange(dim // 2, dtype=torch.float)
    omega /= dim / 2.
    omega = 1. / (10000 ** omega)

    # theta[p, i] = pos_p * omega_i, computed as an outer product
    theta = n_pos_vec[:, None] @ omega[None, :]
    emb_sin = torch.sin(theta)
    emb_cos = torch.cos(theta)

    # even columns hold sin, odd columns hold cos
    position_embedding[:, 0::2] = emb_sin
    position_embedding[:, 1::2] = emb_cos
    return position_embedding
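For context, a table like this is normally added to the token embeddings before the first attention layer. Below is a minimal usage sketch, assuming a hypothetical vocab_size, seq_len, and batch of token ids (none of these appear in the original snippet):

import torch.nn as nn

vocab_size, seq_len, d_model = 1000, 16, 8
token_ids = torch.randint(0, vocab_size, (2, seq_len))       # (batch, seq_len)
tok_emb = nn.Embedding(vocab_size, d_model)(token_ids)       # (batch, seq_len, d_model)
pe = create_pe_absolute_sincos_embedding_gov(
    torch.arange(seq_len, dtype=torch.float), d_model)       # (seq_len, d_model)
x = tok_emb + pe.unsqueeze(0)                                # broadcast PE over the batch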
def create_pe_absolute_sincos_embedding(n_pos_vec, dim):
    """
    Absolute sinusoidal position encoding.
    :param n_pos_vec: 1-D tensor of positions to encode
    :param dim: embedding dimension
    :return: position encoding of shape (len(n_pos_vec), dim)
    """
    assert dim % 2 == 0, "dim must be even"
    position_embedding = torch.zeros(n_pos_vec.numel(), dim, dtype=torch.float)

    # same frequencies as above, built step by step: omega_i = 10000^(-2i/dim)
    omega = torch.arange(dim // 2, dtype=torch.float)
    omega *= 2
    omega /= dim
    omega = torch.pow(10000, omega)
    omega = 1 / omega

    print("n_pos_vec shape:", n_pos_vec.unsqueeze(1).shape)
    print("omega shape:", omega.shape)

    # broadcast positions (n, 1) against frequencies (dim/2,) -> (n, dim/2)
    position_embedding[:, 0::2] = torch.sin(n_pos_vec.unsqueeze(1) * omega)
    position_embedding[:, 1::2] = torch.cos(n_pos_vec.unsqueeze(1) * omega)
    return position_embedding
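Both functions implement the sinusoidal encoding from "Attention Is All You Need": PE(pos, 2i) = sin(pos / 10000^(2i/dim)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/dim)). The same table can also be built with a single outer product; the following variant is a sketch (the name create_pe_compact is ours, not from the original):

def create_pe_compact(n_pos_vec, dim):
    # 1 / 10000^(2i/dim), identical to omega in both functions above
    freqs = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
    theta = torch.outer(n_pos_vec, freqs)                    # (n, dim/2)
    # interleave sin/cos so even columns are sin and odd columns are cos
    return torch.stack((torch.sin(theta), torch.cos(theta)), dim=-1).flatten(1)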
if __name__ == "__main__":
    n_pos = 4
    dim = 8
    n_pos_vec = torch.arange(n_pos, dtype=torch.float)
    position_embedding = create_pe_absolute_sincos_embedding(n_pos_vec, dim)
    position_embedding_gov = create_pe_absolute_sincos_embedding_gov(n_pos_vec, dim)
    print(position_embedding == position_embedding_gov)  # elementwise boolean tensor
    print("position embedding shape:", position_embedding.shape)
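The elementwise == above prints the full boolean tensor; for a single pass/fail answer, torch.allclose (which also tolerates floating-point rounding) is the usual check. A one-line addition, not in the original, that could go at the end of the __main__ block:

    print("implementations agree:", torch.allclose(position_embedding, position_embedding_gov))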