# Compare each layer's hidden states for the n-token input against the
# first n positions of the (n+1)-token input.
for i, (hidden_n, hidden_n_plus_1) in enumerate(zip(output_n.hidden_states, output_n_plus_1.hidden_states)):
    print(f"layer {i}, max difference {(hidden_n - hidden_n_plus_1[:, :-1, :]).abs().max().item()}")
    assert torch.allclose(hidden_n, hidden_n_plus_1[:, :-1, :], atol=1e-4)
Output:
layer 0, max difference 0.0
layer 1, max difference 5.7220458984375e-06
layer 2, max difference 5.7220458984375e-06
layer 3, max difference 7.62939453125e-06
layer 4, max difference 2.86102294921875e-05
layer 5, max difference 1.9073486328125e-05
layer 6, max difference 9.5367431640625e-06
layer 7, max difference 1.9073486328125e-05
layer 8, max difference 2.6702880859375e-05
layer 9, max difference 2.6702880859375e-05
layer 10, max difference 2.6702880859375e-05
layer 11, max difference 3.0517578125e-05
layer 12, max difference 3.0517578125e-05
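The same property can be reproduced on a small scale without loading a model. The sketch below assumes the model being tested is a causal (decoder-only) transformer, so that position t attends only to positions 0..t. The function `causal_attention` is a hypothetical toy written for this illustration: it has no learned weights and is not the model's actual attention implementation.

```python
import math
import random


def causal_attention(xs):
    """Toy single-head causal self-attention over a list of vectors.

    The inputs double as queries, keys, and values (no learned weights),
    which is enough to illustrate the effect of the causal mask.
    """
    dim = len(xs[0])
    outputs = []
    for t, query in enumerate(xs):
        # Causal mask: position t may only look at positions 0..t.
        visible = xs[: t + 1]
        scores = [
            sum(q * k for q, k in zip(query, key)) / math.sqrt(dim)
            for key in visible
        ]
        # Numerically stable softmax over the visible positions.
        top = max(scores)
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        outputs.append([
            sum(w * v[d] for w, v in zip(weights, visible)) / total
            for d in range(dim)
        ])
    return outputs


random.seed(0)
# A random "sequence" of 6 vectors, and the same sequence without its last item.
seq_n_plus_1 = [[random.gauss(0, 1) for _ in range(4)] for _ in range(6)]
seq_n = seq_n_plus_1[:-1]

out_n = causal_attention(seq_n)
out_n_plus_1 = causal_attention(seq_n_plus_1)

# Compare the n outputs with the first n outputs of the longer sequence.
max_diff = max(
    abs(a - b)
    for row_n, row_n_plus_1 in zip(out_n, out_n_plus_1[:-1])
    for a, b in zip(row_n, row_n_plus_1)
)
print(f"max difference {max_diff}")
```

In this toy the difference is exactly zero, because both calls perform the same arithmetic in the same order for the shared positions. In a real model the small nonzero differences (on the order of 1e-5 above) are most likely floating-point rounding: batched kernels may accumulate sums in a different order when the input shape changes, which is why the check uses `atol=1e-4` rather than exact equality.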