Studying "Build a Large Language Model (From Scratch) by Sebastian Raschka. I was trying to evaluate "Multi Head Simple Implementation" and Alternative Implmentation as mentioned in the book using following code. As per my understanding both method should give me same output. But in my following code values are different (i.e context_vec_Final in Simple implementation and context_vec in Alternative implementation should be same). Can you please help me to correct my mistake. For simplicity purpose, I discarded Causal Attention in this example.
#Simple Implementation
import torch
import torch.nn as nn
torch.manual_seed(1223)
batch = 1
num_tokens = 2
d_out = 6
num_heads = 3
head_dim = d_out // num_heads
out_proj = nn.Linear(d_out*num_heads, d_out*num_heads)
print (f'{batch = } {num_tokens = } {d_out = } {num_heads = } {head_dim = }')
q = torch.rand(batch,num_tokens,d_out)
k = torch.rand(batch,num_tokens,d_out)
v = torch.rand(batch,num_tokens,d_out)
print(f'{q.shape = } {k.shape = } {v.shape = }')
context_vec = []
for _ in range(num_heads):
    attn_scores1 = q @ k.transpose(1, 2)
    context_vec1 = attn_scores1 @ v
    context_vec.append(context_vec1)
context_vec_Final = out_proj(torch.cat (context_vec, dim = -1))
print(f'{context_vec_Final.shape = } {context_vec_Final = }')
context_vec_Final.shape = torch.Size([1, 2, 18]) context_vec_Final = tensor([[[ 0.4481, 0.1600, 1.1673, -0.0284, -0.4112, -0.0324, -0.8480, -0.6424, 1.0077, 0.2438, -0.2781, 0.6076, -0.4314, 0.4139, 0.6364, -0.7987, -0.1409, 1.0451], [ 0.6571, 0.2290, 1.5520, -0.0847, -0.6139, -0.0287, -1.0902, -0.9133, 1.2716, 0.2759, -0.2912, 0.8003, -0.5354, 0.4829, 0.8830, -0.9920, -0.1232, 1.3950]]], grad_fn=)
#Alternative Variation
import torch
torch.manual_seed(1223)
batch = 1
num_tokens = 2
d_out = 6
num_heads = 3
head_dim = d_out // num_heads
print (f'{batch = } {num_tokens = } {d_out = } {num_heads = } {head_dim = }')
q = torch.rand(batch,num_tokens,d_out)
k = torch.rand(batch,num_tokens,d_out)
v = torch.rand(batch,num_tokens,d_out)
print(f'{q.shape = } {k.shape = } {v.shape = }')
print(f'{q = }')
#The key operation is to split the d_out dimension into num_heads and head_dim using view option
#(b, num_tokens, d_out) is reshaped to dimension (b, num_tokens, num_heads, head_dim)
#d_out = num_heads * head_dim
q = q.view(batch,num_tokens,num_heads,head_dim)
k = k.view(batch,num_tokens,num_heads,head_dim)
v = v.view(batch,num_tokens,num_heads,head_dim)
print(f'{q.shape = } {k.shape = } {v.shape = }')
#Now transpose (b, num_tokens, num_heads, head_dim) to (b, num_heads, num_tokens, head_dim)
#Logically we now have num_heads heads, each of shape num_tokens x head_dim
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
print(f'{q.shape = } {k.shape = } {v.shape = }')
attn_scores = q @ k.transpose(2, 3)
print(f'{attn_scores.shape = } {attn_scores = }')
context_vec = (attn_scores @ v)
print(f'{context_vec.shape = }')
#Now we are transposing (b, num_heads, num_tokens, head_dim) back to (b, num_tokens, num_heads, head_dim)
context_vec = context_vec.transpose(1, 2)
print(f'{context_vec.shape = }')
context_vec = context_vec.contiguous().view(batch, num_tokens, d_out)
print(f'{attn_scores.shape = } {context_vec.shape = }')
print (f'{context_vec.shape = } {context_vec = }')
context_vec.shape = torch.Size([1, 2, 6]) context_vec = tensor([[[0.7713, 1.1682, 0.6141, 0.8472, 0.9745, 0.6424], [0.3327, 0.6094, 0.1885, 0.2612, 0.6739, 0.3949]]])
Studying "Build a Large Language Model (From Scratch) by Sebastian Raschka. I was trying to evaluate "Multi Head Simple Implementation" and Alternative Implmentation as mentioned in the book using following code. As per my understanding both method should give me same output. But in my following code values are different (i.e context_vec_Final in Simple implementation and context_vec in Alternative implementation should be same). Can you please help me to correct my mistake. For simplicity purpose, I discarded Causal Attention in this example.
#Simple Implementation
import torch
import torch.nn as nn
torch.manual_seed(1223)
batch = 1
num_tokens = 2
d_out = 6
num_heads = 3
head_dim = d_out // num_heads
out_proj = nn.Linear(d_out*num_heads, d_out*num_heads)
print (f'{batch = } {num_tokens = } {d_out = } {num_heads = } {head_dim = }')
q = torch.rand(batch,num_tokens,d_out)
k = torch.rand(batch,num_tokens,d_out)
v = torch.rand(batch,num_tokens,d_out)
print(f'{q.shape = } {k.shape = } {v.shape = }')
context_vec = []
for _ in range(num_heads) :
attn_scores1 = q @ k.transpose(1, 2)
context_vec1 = attn_scores1 @ v
context_vec.append(context_vec1)
context_vec_Final = out_proj(torch.cat (context_vec, dim = -1))
print(f'{context_vec_Final.shape = } {context_vec_Final = }')
context_vec_Final.shape = torch.Size([1, 2, 18]) context_vec_Final = tensor([[[ 0.4481, 0.1600, 1.1673, -0.0284, -0.4112, -0.0324, -0.8480, -0.6424, 1.0077, 0.2438, -0.2781, 0.6076, -0.4314, 0.4139, 0.6364, -0.7987, -0.1409, 1.0451], [ 0.6571, 0.2290, 1.5520, -0.0847, -0.6139, -0.0287, -1.0902, -0.9133, 1.2716, 0.2759, -0.2912, 0.8003, -0.5354, 0.4829, 0.8830, -0.9920, -0.1232, 1.3950]]], grad_fn=)
Alternative Variation
import torch
torch.manual_seed(1223)
batch = 1
num_tokens = 2
d_out = 6
num_heads = 3
head_dim = d_out // num_heads
print (f'{batch = } {num_tokens = } {d_out = } {num_heads = } {head_dim = }')
q = torch.rand(batch,num_tokens,d_out)
k = torch.rand(batch,num_tokens,d_out)
v = torch.rand(batch,num_tokens,d_out)
print(f'{q.shape = } {k.shape = } {v.shape = }')
print(f'{q = }')
#The key operation is to split the d_out dimension into num_heads and head_dim using view option
#(b, num_tokens, d_out) is reshaped to dimension (b, num_tokens, num_heads, head_dim)
#d_out = num_heads * head_dim
q = q.view(batch,num_tokens,num_heads,head_dim)
k = k.view(batch,num_tokens,num_heads,head_dim)
v = v.view(batch,num_tokens,num_heads,head_dim)
print(f'{q.shape = } {k.shape = } {v.shape = }')
#Now transpose (b, num_tokens, num_heads, head_dim) to (b, num_heads, num_tokens, head_dim)
#Logically now we have num_heads of num_tokens x head_dim...
#Now we have num_heads (multiple head ) of num_tokens x head_dim
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
print(f'{q.shape = } {k.shape = } {v.shape = }')
attn_scores = q @ k.transpose(2, 3)
print(f'{attn_scores.shape = } {attn_scores = }')
context_vec = (attn_scores @ v)
print(f'{context_vec.shape = }')
#Now we are transposing (b, num_heads, num_tokens, head_dim) back to (b, num_tokens, num_heads, head_dim)
context_vec = context_vec.transpose(1, 2)
print(f'{context_vec.shape = }')
context_vec = context_vec.contiguous().view(batch, num_tokens, d_out)
print(f'{attn_scores.shape = } {context_vec.shape = }')
print (f'{context_vec.shape = } {context_vec = }')
context_vec.shape = torch.Size([1, 2, 6]) context_vec = tensor([[[0.7713, 1.1682, 0.6141, 0.8472, 0.9745, 0.6424], [0.3327, 0.6094, 0.1885, 0.2612, 0.6739, 0.3949]]])
1 Answer
So there are a few problems:
1. In the simple implementation you pass the result through a linear layer and in the alternative you don't; you have to do it either for both or for neither.
2. Your dimensions don't match. The naming might be confusing, but for the q, k, v of the simple implementation you should either use head_dim (and, for this example, reuse the same tensors for each head) or use d_out, but then inside the for loop work only with the part that belongs to one head.
3. If you want to use the linear layer, you also need to change it to nn.Linear(d_out, d_out) (in both versions).
4. You have to ensure that both versions see the same parts of q, k, v. Depending on which correction you pick in 2, you either make the simple implementation access the q, k, v of the correct head (a sketch of this option follows the list), or you sample only head_dim values for the alternative and repeat them num_heads times.
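For reference, here is a minimal sketch of the per-head slicing option from point 4 (helper names such as q_h are my own, just for illustration): keep q, k, v at the full d_out width, as in your original simple implementation, and pick out one head's head_dim-wide slice inside the loop.
import torch
import torch.nn as nn
torch.manual_seed(1223)
batch = 1
num_tokens = 2
d_out = 6
num_heads = 3
head_dim = d_out // num_heads
out_proj = nn.Linear(d_out, d_out)
q = torch.rand(batch, num_tokens, d_out)
k = torch.rand(batch, num_tokens, d_out)
v = torch.rand(batch, num_tokens, d_out)
context_vec = []
for h in range(num_heads):
    # take only this head's head_dim-wide slice of q, k, v
    q_h = q[..., h*head_dim:(h+1)*head_dim]
    k_h = k[..., h*head_dim:(h+1)*head_dim]
    v_h = v[..., h*head_dim:(h+1)*head_dim]
    attn_scores_h = q_h @ k_h.transpose(1, 2)
    context_vec.append(attn_scores_h @ v_h)
context_vec_Final = out_proj(torch.cat(context_vec, dim=-1))
print(f'{context_vec_Final.shape = }')
This slicing is exactly what the view/transpose split does in the alternative variation, so if you also draw q, k, v with width d_out there, drop the repeat, and apply the same out_proj after merging the heads, the two outputs should match.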
So to summarize, here is the full code (using the head_dim-plus-repeat option) that produces identical results:
import torch
import torch.nn as nn
torch.manual_seed(1223)
batch = 1
num_tokens = 2
d_out = 6
num_heads = 3
head_dim = d_out // num_heads
out_proj = nn.Linear(d_out, d_out)
print (f'{batch = } {num_tokens = } {d_out = } {num_heads = } {head_dim = }')
q = torch.rand(batch,num_tokens,head_dim)
k = torch.rand(batch,num_tokens,head_dim)
v = torch.rand(batch,num_tokens,head_dim)
print(f'{q.shape = } {k.shape = } {v.shape = }')
context_vec = []
for _ in range(num_heads):
    attn_scores1 = q @ k.transpose(1, 2)
    context_vec1 = attn_scores1 @ v
    context_vec.append(context_vec1)
context_vec_Final = out_proj(torch.cat (context_vec, dim = -1))
print(f'{context_vec_Final.shape = } {context_vec_Final = }')
#Alternative variation: repeat the head_dim-wide q, k, v so every head sees the same values
import torch
import torch.nn as nn
torch.manual_seed(1223)
batch = 1
num_tokens = 2
d_out = 6
num_heads = 3
head_dim = d_out // num_heads
out_proj = nn.Linear(d_out, d_out)
print (f'{batch = } {num_tokens = } {d_out = } {num_heads = } {head_dim = }')
q = torch.rand(batch,num_tokens,head_dim).repeat(1,1,num_heads)
k = torch.rand(batch,num_tokens,head_dim).repeat(1,1,num_heads)
v = torch.rand(batch,num_tokens,head_dim).repeat(1,1,num_heads)
print(f'{q.shape = } {k.shape = } {v.shape = }')
print(f'{q = }')
#The key operation is to split the d_out dimension into num_heads and head_dim using view option
#(b, num_tokens, d_out) is reshaped to dimension (b, num_tokens, num_heads, head_dim)
#d_out = num_heads * head_dim
q = q.view(batch,num_tokens,num_heads,head_dim)
k = k.view(batch,num_tokens,num_heads,head_dim)
v = v.view(batch,num_tokens,num_heads,head_dim)
print(f'{q.shape = } {k.shape = } {v.shape = }')
#Now transpose (b, num_tokens, num_heads, head_dim) to (b, num_heads, num_tokens, head_dim)
#Logically we now have num_heads heads, each of shape num_tokens x head_dim
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
print(f'{q.shape = } {k.shape = } {v.shape = }')
attn_scores = q @ k.transpose(2, 3)
print(f'{attn_scores.shape = } {attn_scores = }')
context_vec = (attn_scores @ v)
print(f'{context_vec.shape = }')
#Now we are transposing (b, num_heads, num_tokens, head_dim) back to (b, num_tokens, num_heads, head_dim)
context_vec = context_vec.transpose(1, 2)
print(f'{context_vec.shape = }')
context_vec = out_proj(context_vec.contiguous().view(batch, num_tokens, d_out))
print(f'{attn_scores.shape = } {context_vec.shape = }')
print (f'{context_vec.shape = } {context_vec = }')
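If you run the two snippets back to back in the same session, a quick sanity check (assuming the names from above, context_vec_Final from the simple version and context_vec from the alternative) is:
# should print True: same seed, same out_proj weights, and each head sees the same values
print(torch.allclose(context_vec_Final, context_vec))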