Investigating the Function of RoPE in Context Windows
Problem
The task is to apply the same relative positions to all segments of the data. We explore this with a toy dataset in JSON key-value format: the model is shown all (Key, Value) pairs and is then prompted to print the value for a given key. The challenge is to ensure that the model attends to every (Key, Value) pair at the same relative position.
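To make the setup concrete, here is a minimal sketch of how such a toy prompt could be generated. The function name `make_kv_prompt`, the random hex keys and values, and the prompt wording are all assumptions for illustration; the post does not specify the exact format.

```python
import json
import random


def make_kv_prompt(num_pairs, seed=0):
    """Build a toy JSON key-value retrieval prompt (hypothetical format)."""
    rng = random.Random(seed)
    # Random 8-character hex strings stand in for keys and values.
    pairs = {
        f"{rng.getrandbits(32):08x}": f"{rng.getrandbits(32):08x}"
        for _ in range(num_pairs)
    }
    # Pick one key the model will be asked about.
    query_key = rng.choice(list(pairs))
    prompt = (
        "JSON data:\n"
        + json.dumps(pairs)
        + f'\nWhat is the value for key "{query_key}"?'
    )
    return prompt, query_key, pairs[query_key]


prompt, query_key, answer = make_kv_prompt(4)
print(prompt)
```

Each (Key, Value) pair is one segment; the goal described above is for the model to see every such segment at the same relative position with respect to the final question.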
Experiment 1
Preliminaries
The Unlimiformer paper presents a faster extension of the Memorizing Transformer paper. Unlike the latter, Unlimiformer maintains a single index instead of an index per head per layer. This is possible because the model it examines is an encoder-decoder model, where all key vectors are derived from the output of the last encoder layer. The attention score (x_q^T W_q^T W_k x_e) can therefore be split into an index entry (x_e) and a query to the vector store (x_q^T W_q^T W_k). The indexed vectors stay the same across heads and attention layers; only the query changes. Here x_e is the output hidden state of the last encoder layer, which is what an encoder-decoder model uses to compute key vectors.
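The regrouping above is just associativity of matrix products: (W_q x_q)^T (W_k x_e) = (W_k^T W_q x_q)^T x_e. A minimal pure-Python check follows, with toy dimensions and random weights chosen purely for illustration (they are not taken from any real model).

```python
import random


def matvec(M, v):
    """Multiply matrix M (list of rows) by vector v."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def transpose(M):
    return [list(col) for col in zip(*M)]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


rng = random.Random(0)
d_model, d_head = 6, 4  # hypothetical toy sizes
W_q = [[rng.gauss(0, 1) for _ in range(d_model)] for _ in range(d_head)]
W_k = [[rng.gauss(0, 1) for _ in range(d_model)] for _ in range(d_head)]
x_q = [rng.gauss(0, 1) for _ in range(d_model)]
encoder_states = [[rng.gauss(0, 1) for _ in range(d_model)] for _ in range(5)]

# Standard attention: project both sides, then take the dot product.
standard = [dot(matvec(W_q, x_q), matvec(W_k, x_e)) for x_e in encoder_states]

# Reordered: the index stores raw x_e; W_k is folded into the query.
store_query = matvec(transpose(W_k), matvec(W_q, x_q))
reordered = [dot(store_query, x_e) for x_e in encoder_states]

print(all(abs(a - b) < 1e-9 for a, b in zip(standard, reordered)))
```

Because the stored vectors are the unprojected x_e, changing the head or layer only changes `store_query`, not the index.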
Unlimiformer for Decoder-Only Models
When dealing with decoder-only models, we encounter two challenges:
1. Different Inputs for Key and Value Computation
In decoder-only models, the input to compute keys and values comes from the output of the previous decoder layer, which varies for each layer. This necessitates having as many indexes as there are decoder layers.
2. RoPE Embeddings and Efficient Indexing
Recent open-source models use RoPE (Rotary Position Embeddings), which encode the relative position between key and query vectors. The attention score now includes a rotation: (x_q^T W_q^T R(s) W_k x_k), where R(s) is a rotation matrix that depends on the relative distance (s) between the query and key tokens. To build an efficient index akin to Unlimiformer, we need to limit the number of indexes to the number of decoder layers. In the attention score, R(s) and x_k both depend on the key token, so it is desirable to store them in the index, while the rest of the computation belongs to the vector store query.
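The relative-position property can be checked numerically: rotating the query by its position m and the key by its position n yields a score that depends only on m - n. The sketch below uses the standard RoPE frequency schedule with base 10000 and an adjacent-pair layout; the vectors and positions are arbitrary illustrative values.

```python
import math


def rope(vec, pos, base=10000.0):
    """Rotate each pair (vec[i], vec[i+1]) by the angle pos * theta."""
    d = len(vec)
    out = []
    for i in range(0, d, 2):
        theta = base ** (-i / d)  # i steps by 2, so this is base^(-2j/d)
        c, s = math.cos(pos * theta), math.sin(pos * theta)
        x, y = vec[i], vec[i + 1]
        out += [x * c - y * s, x * s + y * c]
    return out


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


q = [0.3, -1.2, 0.7, 0.5]
k = [1.1, 0.4, -0.6, 0.9]

# Same relative distance (3), very different absolute positions.
score_near = dot(rope(q, 10), rope(k, 7))
score_far = dot(rope(q, 103), rope(k, 100))
print(abs(score_near - score_far) < 1e-9)
```

This is also what makes indexing awkward: the rotation a stored key needs depends on where the query sits relative to it, which is not known when the index is built.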
Examination of Unlimiformer’s Implementation for Decoder-Only Models
The Unlimiformer implementation makes two approximations, which are confirmed experimentally:
1. Constructing the Vectorstore Query
The query is calculated as follows: (R(m) * W_q * x_q)^T (W_k + Rotated(W_k)). Rotated(·) is an efficient way to apply the rotation that avoids a matrix-matrix multiplication. The implementation effectively approximates (cos = sin), because in the code cos and sin are handled as precomputed vectors rather than as trigonometric functions.
2. Applying Final Relative Position Indices
After retrieving the top-k closest hidden states and projecting them with W_k, Unlimiformer applies the rotation uniformly, treating the first retrieved key as the origin. However, the query is also assigned a relative position, with its origin at the first generated token. This mismatch may skew the query’s attention, favoring the middle retrieved key over those on either side.
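For context on the rotation shortcut mentioned in the first approximation: a common way to apply RoPE without building a rotation matrix is the element-wise form x * cos + rotate_half(x) * sin, where cos and sin are precomputed vectors. The sketch below checks that this matches an explicit per-pair rotation. It assumes the half-split pairing (element i paired with element i + d/2) used in some open-source LLaMA-style code; whether Unlimiformer's code follows exactly this layout is an assumption.

```python
import math


def rotate_half(x):
    """Map [x1, x2] (two halves) to [-x2, x1]."""
    h = len(x) // 2
    return [-v for v in x[h:]] + x[:h]


def rope_elementwise(x, pos, base=10000.0):
    """Apply RoPE using only element-wise products with cos/sin vectors."""
    h = len(x) // 2
    angles = [pos * base ** (-i / h) for i in range(h)] * 2
    rot = rotate_half(x)
    return [xi * math.cos(a) + ri * math.sin(a)
            for xi, ri, a in zip(x, rot, angles)]


def rope_explicit(x, pos, base=10000.0):
    """Apply the same rotation pair by pair, as a 2x2 rotation would."""
    h = len(x) // 2
    out = list(x)
    for i in range(h):
        a = pos * base ** (-i / h)
        c, s = math.cos(a), math.sin(a)
        out[i] = x[i] * c - x[i + h] * s
        out[i + h] = x[i] * s + x[i + h] * c
    return out


x = [0.5, -1.0, 2.0, 0.25, 1.5, -0.75]
fast = rope_elementwise(x, pos=7)
slow = rope_explicit(x, pos=7)
print(all(abs(a - b) < 1e-9 for a, b in zip(fast, slow)))
```

Note that the element-wise form keeps separate cos and sin terms; collapsing them into one term, as in (W_k + Rotated(W_k)), is only exact when the two vectors coincide.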
High-Complexity Retrieval Methodology
To address these challenges, a more involved series of steps converts the attention score into a cosine similarity between a vector that depends on x_k and one that is independent of x_k. The cost is that the embedding dimension in the vector store grows quadratically, from the transformer’s embedding dimension to (attention-dim * embedding-dim).
Query and Key Values
- Query to vectorstore: Concatenate [x . [cos(s * theta_1), cos(s * theta_2), …, cos(s * theta_(a/2-1))]^T, x . [sin(s * theta_1), sin(s * theta_2), …, sin(s * theta_(a/2-1))]^T]
- Vectors in vectorstore: Concatenate [torch.sum(torch.cat(torch.split(q . W_k, 2, dim=0), dim=1), dim=0), …]
Derivation Details
Experiment 2
Is the RoPE corresponding to retrieved keys even necessary during generation?