Multi-head attention is a key innovation in the transformer architecture that allows the model to attend to different parts of the input sequence simultaneously and learn diverse types of information.
In multi-head attention, the [[self-attention]] mechanism is performed multiple times in parallel, each time with a different learned set of [[query]], [[key]], and [[value]] vectors produced by per-head projection matrices. These parallel instances of attention are referred to as [[attention head|attention heads]]. The original Transformer paper ("Attention Is All You Need", Vaswani et al., 2017) used 8 attention heads.
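A minimal NumPy sketch of this per-head projection step. The dimensions (d_model = 512, 8 heads of size 64) match the original paper, but the input sequence and the random weight matrices are purely illustrative assumptions, not trained values.

```python
import numpy as np

rng = np.random.default_rng(0)

seq_len, d_model, num_heads = 10, 512, 8
d_k = d_model // num_heads  # 64 dimensions per head, as in the original paper

x = rng.normal(size=(seq_len, d_model))            # input embeddings (illustrative)

# One learned projection per head for queries, keys, and values.
W_q = rng.normal(size=(num_heads, d_model, d_k)) * 0.02
W_k = rng.normal(size=(num_heads, d_model, d_k)) * 0.02
W_v = rng.normal(size=(num_heads, d_model, d_k)) * 0.02

# Project the same input into a different Q/K/V space for each head.
Q = np.einsum("sd,hdk->hsk", x, W_q)   # (num_heads, seq_len, d_k)
K = np.einsum("sd,hdk->hsk", x, W_k)
V = np.einsum("sd,hdk->hsk", x, W_v)
```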
For each attention head, [[attention score|attention scores]] are calculated by taking the dot product between the query vector of the current position and the key vectors of all positions in the sequence.
The calculated dot products are divided by the square root of the dimension of the key vectors (this is known as scaled dot-product attention). Without this scaling, large dot products would push the softmax into regions where its gradients are extremely small, slowing training.
The scaled attention scores for each position are passed through the softmax function, yielding attention weights that sum to 1 across the sequence. The value vectors are then weighted by these attention weights and summed, producing an output representation for each position.
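Continuing the sketch above, each head applies scaled dot-product attention, softmax(QKᵀ / √d_k)·V, to its own `Q`, `K`, and `V` arrays from the previous snippet:

```python
def scaled_dot_product_attention(Q, K, V):
    """Q, K, V: (num_heads, seq_len, d_k) arrays from the projection step."""
    d_k = Q.shape[-1]
    # Dot product of each query with every key, scaled by sqrt(d_k).
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)   # (num_heads, seq_len, seq_len)
    # Softmax over the key positions, so each row of weights sums to 1.
    scores -= scores.max(axis=-1, keepdims=True)        # for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    # Weighted sum of the value vectors for each position.
    return weights @ V                                   # (num_heads, seq_len, d_k)

head_outputs = scaled_dot_product_attention(Q, K, V)
```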
The output representations from all attention heads are concatenated along the feature dimension. This concatenated output is then passed through another learned linear transformation (the output projection) to produce the final multi-head self-attention output.
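The per-head outputs from the snippet above can then be concatenated and projected back to the model dimension; the output projection matrix here is random and purely illustrative.

```python
# Concatenate the heads along the feature dimension: (seq_len, num_heads * d_k).
concat = head_outputs.transpose(1, 0, 2).reshape(seq_len, num_heads * d_k)

# Final learned linear transformation back to d_model dimensions.
W_o = rng.normal(size=(num_heads * d_k, d_model)) * 0.02
attention_output = concat @ W_o                      # (seq_len, d_model)
```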
In a transformer block, the multi-head self-attention output is followed by a position-wise [[feedforward neural network]] layer, which introduces further non-linearity and captures complex interactions between features.
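A sketch of that position-wise feedforward layer, applied independently to each position's vector. The hidden width of 2048 matches the original paper; the ReLU and the random weights are standard but illustrative choices.

```python
# Position-wise feedforward network: two linear layers with a ReLU in between.
d_ff = 2048
W_1, b_1 = rng.normal(size=(d_model, d_ff)) * 0.02, np.zeros(d_ff)
W_2, b_2 = rng.normal(size=(d_ff, d_model)) * 0.02, np.zeros(d_model)

ffn_output = np.maximum(0, attention_output @ W_1 + b_1) @ W_2 + b_2   # (seq_len, d_model)
```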
In summary, multi-head self-attention allows the transformer model to process input sequences in parallel through multiple [[attention head|attention heads]]. Each head focuses on different relationships and patterns, and their outputs are combined to create a richer and more diverse representation of the data.
[[attention head]] < [[Hands-on LLMs]]/[[2 LLMs and Transformers]] > [[residual connection]]