I am working with two embeddings, one for text and one for images. Both are the `last_hidden_state` of transformer models (BERT and ViT), so their shapes are `(batch, seq, emb_dim)`. I want to feed text information into the image representation using a cross-attention mechanism, and I was wondering whether this code gives me what I need:
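```python
import torch.nn as nn
# inputs are (batch, seq, emb_dim), so batch_first=True (the default expects (seq, batch, emb_dim))
cross_attention = nn.MultiheadAttention(embed_dim=768, num_heads=12, dropout=0.1, batch_first=True)
attn_output, attn_output_weights = cross_attention(text_last, img_last, img_last)  # query=text, key=value=image
```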
I tried this code, but I am not sure whether it is the right approach.
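One point worth checking is the argument order. `nn.MultiheadAttention` returns one output vector per *query* position. With `cross_attention(text_last, img_last, img_last)` the text tokens are the queries, so the output has the text sequence length: it is text enriched with image information. To feed text information into the image tokens, the image tokens would be the query and the text tokens the key and value. Below is a minimal sketch of that arrangement. It assumes PyTorch 1.9 or later (for `batch_first`) and 768-dim hidden states as in BERT-base and ViT-base. Random tensors stand in for the real `last_hidden_state` outputs, and the sequence lengths and padding mask are hypothetical values for illustration.

```python
import torch
import torch.nn as nn

# Hypothetical sizes: 32 text tokens; 197 image tokens (196 patches + CLS for a 224px ViT-base/16).
batch, text_len, img_len, emb_dim = 2, 32, 197, 768

# Random stand-ins for the real last_hidden_state tensors (not real model outputs).
text_last = torch.randn(batch, text_len, emb_dim)
img_last = torch.randn(batch, img_len, emb_dim)

# batch_first=True because both tensors are (batch, seq, emb_dim).
cross_attention = nn.MultiheadAttention(
    embed_dim=emb_dim, num_heads=12, dropout=0.1, batch_first=True
)
cross_attention.eval()  # disable dropout so this demo is deterministic

# Hypothetical tokenizer-style attention mask: 1 = real token, 0 = padding.
text_attention_mask = torch.ones(batch, text_len, dtype=torch.long)
text_attention_mask[:, 20:] = 0  # pretend the last 12 text tokens are padding

# Image tokens are the queries; text tokens supply keys and values,
# so each image token gathers information from the text sequence.
# key_padding_mask is True where a key position should be ignored.
attn_output, attn_weights = cross_attention(
    query=img_last,
    key=text_last,
    value=text_last,
    key_padding_mask=(text_attention_mask == 0),
)

# Residual connection: enrich the image features rather than replace them
# (a common design choice, not a requirement).
img_fused = img_last + attn_output

print(attn_output.shape)   # torch.Size([2, 197, 768]): one vector per image token
print(attn_weights.shape)  # torch.Size([2, 197, 32]): averaged over heads by default
```

The output now has the image sequence length, so it can replace or be added to `img_last` downstream. If the text and image hidden sizes ever differ, the `kdim` and `vdim` constructor arguments allow keys and values whose size differs from `embed_dim`.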