
python - Multi-head attention: Simple Implementation vs Logical Split of Query, Key and Value Matrix - Stack Overflow


Studying "Build a Large Language Model (From Scratch) by Sebastian Raschka. I was trying to evaluate "Multi Head Simple Implementation" and Alternative Implmentation as mentioned in the book using following code. As per my understanding both method should give me same output. But in my following code values are different (i.e context_vec_Final in Simple implementation and context_vec in Alternative implementation should be same). Can you please help me to correct my mistake. For simplicity purpose, I discarded Causal Attention in this example.

#Simple Implementation
import torch
import torch.nn as nn
torch.manual_seed(1223)

batch = 1
num_tokens = 2
d_out  = 6
num_heads = 3
head_dim = d_out // num_heads

out_proj = nn.Linear(d_out*num_heads, d_out*num_heads)
print (f'{batch = } {num_tokens = } {d_out = } {num_heads = } {head_dim = }')
q = torch.rand(batch,num_tokens,d_out)
k = torch.rand(batch,num_tokens,d_out)
v = torch.rand(batch,num_tokens,d_out)
print(f'{q.shape = } {k.shape = } {v.shape = }')
context_vec = []
for _ in range(num_heads) :
    attn_scores1 = q @ k.transpose(1, 2)
    context_vec1 = attn_scores1 @ v
    context_vec.append(context_vec1)

context_vec_Final = out_proj(torch.cat (context_vec, dim = -1))
print(f'{context_vec_Final.shape = } {context_vec_Final = }')

context_vec_Final.shape = torch.Size([1, 2, 18]) context_vec_Final = tensor([[[ 0.4481, 0.1600, 1.1673, -0.0284, -0.4112, -0.0324, -0.8480, -0.6424, 1.0077, 0.2438, -0.2781, 0.6076, -0.4314, 0.4139, 0.6364, -0.7987, -0.1409, 1.0451], [ 0.6571, 0.2290, 1.5520, -0.0847, -0.6139, -0.0287, -1.0902, -0.9133, 1.2716, 0.2759, -0.2912, 0.8003, -0.5354, 0.4829, 0.8830, -0.9920, -0.1232, 1.3950]]], grad_fn=)

Alternative Variation

import torch
torch.manual_seed(1223)

batch = 1
num_tokens = 2
d_out  = 6
num_heads = 3
head_dim = d_out // num_heads

print (f'{batch = } {num_tokens = } {d_out = } {num_heads = } {head_dim = }')
q = torch.rand(batch,num_tokens,d_out)
k = torch.rand(batch,num_tokens,d_out)
v = torch.rand(batch,num_tokens,d_out)
print(f'{q.shape = } {k.shape = } {v.shape = }')
print(f'{q = }')
#The key operation is to split the d_out dimension into num_heads and head_dim using view option
#(b, num_tokens, d_out) is reshaped to dimension (b, num_tokens, num_heads, head_dim)
#d_out = num_heads * head_dim
q = q.view(batch,num_tokens,num_heads,head_dim)
k = k.view(batch,num_tokens,num_heads,head_dim)
v = v.view(batch,num_tokens,num_heads,head_dim)
print(f'{q.shape = } {k.shape = } {v.shape = }')

#Now transpose (b, num_tokens, num_heads, head_dim) to (b, num_heads, num_tokens, head_dim)
#Logically now we have num_heads of num_tokens x head_dim...
#Now we have num_heads (multiple head ) of num_tokens x head_dim
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
print(f'{q.shape = } {k.shape = } {v.shape = }')

attn_scores = q @ k.transpose(2, 3)
print(f'{attn_scores.shape = } {attn_scores = }')


context_vec = (attn_scores @ v)
print(f'{context_vec.shape = }')

#Now we are transposing (b, num_heads, num_tokens, head_dim) back to (b, num_tokens, num_heads, head_dim)
context_vec = context_vec.transpose(1, 2) 
print(f'{context_vec.shape = }')

context_vec = context_vec.contiguous().view(batch, num_tokens, d_out)
print(f'{attn_scores.shape = } {context_vec.shape = }')
print (f'{context_vec.shape = } {context_vec = }')

context_vec.shape = torch.Size([1, 2, 6]) context_vec = tensor([[[0.7713, 1.1682, 0.6141, 0.8472, 0.9745, 0.6424], [0.3327, 0.6094, 0.1885, 0.2612, 0.6739, 0.3949]]])

Studying "Build a Large Language Model (From Scratch) by Sebastian Raschka. I was trying to evaluate "Multi Head Simple Implementation" and Alternative Implmentation as mentioned in the book using following code. As per my understanding both method should give me same output. But in my following code values are different (i.e context_vec_Final in Simple implementation and context_vec in Alternative implementation should be same). Can you please help me to correct my mistake. For simplicity purpose, I discarded Causal Attention in this example.

#Simple Implementation
import torch
import torch.nn as nn
torch.manual_seed(1223)

batch = 1
num_tokens = 2
d_out  = 6
num_heads = 3
head_dim = d_out // num_heads

out_proj = nn.Linear(d_out*num_heads, d_out*num_heads)
print (f'{batch = } {num_tokens = } {d_out = } {num_heads = } {head_dim = }')
q = torch.rand(batch,num_tokens,d_out)
k = torch.rand(batch,num_tokens,d_out)
v = torch.rand(batch,num_tokens,d_out)
print(f'{q.shape = } {k.shape = } {v.shape = }')
context_vec = []
for _ in range(num_heads) :
    attn_scores1 = q @ k.transpose(1, 2)
    context_vec1 = attn_scores1 @ v
    context_vec.append(context_vec1)

context_vec_Final = out_proj(torch.cat (context_vec, dim = -1))
print(f'{context_vec_Final.shape = } {context_vec_Final = }')

context_vec_Final.shape = torch.Size([1, 2, 18]) context_vec_Final = tensor([[[ 0.4481, 0.1600, 1.1673, -0.0284, -0.4112, -0.0324, -0.8480, -0.6424, 1.0077, 0.2438, -0.2781, 0.6076, -0.4314, 0.4139, 0.6364, -0.7987, -0.1409, 1.0451], [ 0.6571, 0.2290, 1.5520, -0.0847, -0.6139, -0.0287, -1.0902, -0.9133, 1.2716, 0.2759, -0.2912, 0.8003, -0.5354, 0.4829, 0.8830, -0.9920, -0.1232, 1.3950]]], grad_fn=)

Alternative Variation

import torch
torch.manual_seed(1223)

batch = 1
num_tokens = 2
d_out  = 6
num_heads = 3
head_dim = d_out // num_heads

print (f'{batch = } {num_tokens = } {d_out = } {num_heads = } {head_dim = }')
q = torch.rand(batch,num_tokens,d_out)
k = torch.rand(batch,num_tokens,d_out)
v = torch.rand(batch,num_tokens,d_out)
print(f'{q.shape = } {k.shape = } {v.shape = }')
print(f'{q = }')
#The key operation is to split the d_out dimension into num_heads and head_dim using view option
#(b, num_tokens, d_out) is reshaped to dimension (b, num_tokens, num_heads, head_dim)
#d_out = num_heads * head_dim
q = q.view(batch,num_tokens,num_heads,head_dim)
k = k.view(batch,num_tokens,num_heads,head_dim)
v = v.view(batch,num_tokens,num_heads,head_dim)
print(f'{q.shape = } {k.shape = } {v.shape = }')

#Now transpose (b, num_tokens, num_heads, head_dim) to (b, num_heads, num_tokens, head_dim)
#Logically now we have num_heads of num_tokens x head_dim...
#Now we have num_heads (multiple head ) of num_tokens x head_dim
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
print(f'{q.shape = } {k.shape = } {v.shape = }')

attn_scores = q @ k.transpose(2, 3)
print(f'{attn_scores.shape = } {attn_scores = }')


context_vec = (attn_scores @ v)
print(f'{context_vec.shape = }')

#Now we are transposing (b, num_heads, num_tokens, head_dim) back to (b, num_tokens, num_heads, head_dim)
context_vec = context_vec.transpose(1, 2) 
print(f'{context_vec.shape = }')

context_vec = context_vec.contiguous().view(batch, num_tokens, d_out)
print(f'{attn_scores.shape = } {context_vec.shape = }')
print (f'{context_vec.shape = } {context_vec = }')

context_vec.shape = torch.Size([1, 2, 6]) context_vec = tensor([[[0.7713, 1.1682, 0.6141, 0.8472, 0.9745, 0.6424], [0.3327, 0.6094, 0.1885, 0.2612, 0.6739, 0.3949]]])

asked Mar 18 at 9:25 by KJG, edited Mar 18 at 10:40

1 Answer


So there are a few problems:

  1. In the simple implementation you pass the result through a linear layer, while in the alternative you don't. You have to either do it for both or for neither.

  2. Your dimensions don't match. The naming might be confusing, but for the q, k, v of the simple implementation you should either use head_dim (and, for this example, reuse them for each head) or use d_out, but then in the for loop access only the slice belonging to one head.

  3. If you want to use the linear layer, you also need to adjust it to nn.Linear(d_out, d_out) (for both implementations).

  4. You have to ensure that both versions see the same parts of q, k, v. Depending on your correction in point 2, you either have to make sure the simple implementation accesses the q, k, v of the correct head, or you have to sample only head_dim values for the alternative and then repeat them num_heads times (a minimal sketch of the first option follows below).
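
For completeness, here is a minimal sketch of the first option from point 4 (my illustration, not part of the original answer): keep q, k, v at the full d_out width and slice out the columns of head h inside the loop. It assumes q, k, v have shape (batch, num_tokens, d_out) with d_out = num_heads * head_dim, exactly as in the question.

# Sketch only: per-head column slicing in the "simple" loop
context_vec = []
for h in range(num_heads):
    # take only the head_dim columns belonging to head h
    q_h = q[..., h * head_dim:(h + 1) * head_dim]
    k_h = k[..., h * head_dim:(h + 1) * head_dim]
    v_h = v[..., h * head_dim:(h + 1) * head_dim]
    attn_scores_h = q_h @ k_h.transpose(1, 2)
    context_vec.append(attn_scores_h @ v_h)

# concatenating the per-head results restores the (batch, num_tokens, d_out) layout
context_vec_simple = torch.cat(context_vec, dim=-1)

This slicing is exactly what view(batch, num_tokens, num_heads, head_dim) followed by transpose(1, 2) does logically, so the per-head results (and hence the concatenated output) line up with the alternative implementation.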

So to summarize, the code that produces identical results:

# Simple implementation with the fixes applied
import torch
import torch.nn as nn
torch.manual_seed(1223)

batch = 1
num_tokens = 2
d_out  = 6
num_heads = 3
head_dim = d_out // num_heads

out_proj = nn.Linear(d_out, d_out)
print (f'{batch = } {num_tokens = } {d_out = } {num_heads = } {head_dim = }')
q = torch.rand(batch,num_tokens,head_dim)
k = torch.rand(batch,num_tokens,head_dim)
v = torch.rand(batch,num_tokens,head_dim)
print(f'{q.shape = } {k.shape = } {v.shape = }')
context_vec = []
for _ in range(num_heads) :
    attn_scores1 = q @ k.transpose(1, 2)
    context_vec1 = attn_scores1 @ v
    context_vec.append(context_vec1)

context_vec_Final = out_proj(torch.cat (context_vec, dim = -1))
print(f'{context_vec_Final.shape = } {context_vec_Final = }')

# Alternative implementation with the same fixes (re-seed so weights and inputs match)
import torch
import torch.nn as nn
torch.manual_seed(1223)

batch = 1
num_tokens = 2
d_out  = 6
num_heads = 3
head_dim = d_out // num_heads

out_proj = nn.Linear(d_out, d_out)
print (f'{batch = } {num_tokens = } {d_out = } {num_heads = } {head_dim = }')
q = torch.rand(batch,num_tokens,head_dim).repeat(1,1,num_heads)
k = torch.rand(batch,num_tokens,head_dim).repeat(1,1,num_heads)
v = torch.rand(batch,num_tokens,head_dim).repeat(1,1,num_heads)
print(f'{q.shape = } {k.shape = } {v.shape = }')
print(f'{q = }')
#The key operation is to split the d_out dimension into num_heads and head_dim using view option
#(b, num_tokens, d_out) is reshaped to dimension (b, num_tokens, num_heads, head_dim)
#d_out = num_heads * head_dim
q = q.view(batch,num_tokens,num_heads,head_dim)
k = k.view(batch,num_tokens,num_heads,head_dim)
v = v.view(batch,num_tokens,num_heads,head_dim)
print(f'{q.shape = } {k.shape = } {v.shape = }')

#Now transpose (b, num_tokens, num_heads, head_dim) to (b, num_heads, num_tokens, head_dim)
#Logically now we have num_heads of num_tokens x head_dim...
#Now we have num_heads (multiple head ) of num_tokens x head_dim
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
print(f'{q.shape = } {k.shape = } {v.shape = }')

attn_scores = q @ k.transpose(2, 3)
print(f'{attn_scores.shape = } {attn_scores = }')


context_vec = (attn_scores @ v)
print(f'{context_vec.shape = }')

#Now we are transposing (b, num_heads, num_tokens, head_dim) back to (b, num_tokens, num_heads, head_dim)
context_vec = context_vec.transpose(1, 2) 
print(f'{context_vec.shape = }')

context_vec = out_proj(context_vec.contiguous().view(batch, num_tokens, d_out))
print(f'{attn_scores.shape = } {context_vec.shape = }')
print (f'{context_vec.shape = } {context_vec = }')
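
As a quick sanity check (my addition, assuming the two scripts above are run in the same session so both final variables are still in scope), the outputs of the two corrected implementations can be compared numerically. Re-seeding with torch.manual_seed(1223) and creating out_proj before sampling q, k, v is what makes the random weights and inputs of the second script line up with the first.

# Sanity check: context_vec_Final comes from the simple loop,
# context_vec from the view/transpose version
print(torch.allclose(context_vec_Final, context_vec))  # expected: True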
