Reddit - r/MachineLearning

An Update on Matrix Recurrent Units, an Attention Alternative [R]

Background

I recently revisited my matrix recurrent units algorithm (the MRU), a novel linear-time sequence architecture I created as an alternative to attention. I explain it in depth in the repo, but the gist is that the MRU transforms each embedding into an input state matrix, cumulatively multiplies those matrices along the sequence dimension to get the output state matrices, and then transforms each output matrix back into a vector. To make the MRU efficient on DL hardware, I implemented it as a parallel scan, which is possible because matrix multiplication is associative.
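A minimal sketch of the cumulative matrix product described above, written in plain Python rather than taken from the repo. The function names are hypothetical, and the Hillis-Steele scan is only one possible scan scheme, chosen here as an assumption to show how associativity lets the sequential loop be replaced by about log2(T) rounds:

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain product of two square matrices stored as nested lists."""
    n = len(a)
    return [[sum(a[i][k] * b[k][j] for k in range(n)) for j in range(n)]
            for i in range(n)]


def sequential_scan(states: List[Matrix]) -> List[Matrix]:
    """Reference version: out[t] = states[0] @ states[1] @ ... @ states[t]."""
    out = [states[0]]
    for m in states[1:]:
        out.append(matmul(out[-1], m))
    return out


def parallel_scan(states: List[Matrix]) -> List[Matrix]:
    """Hillis-Steele inclusive scan; every round could run in parallel over t."""
    out = list(states)
    offset = 1
    while offset < len(out):
        # Position t combines with the partial product ending at t - offset.
        # The earlier factor stays on the left: matrix products do not commute.
        out = [out[t] if t < offset else matmul(out[t - offset], out[t])
               for t in range(len(out))]
        offset *= 2
    return out
```

Both functions return the same list of output states; the scan version does more total work but needs only a logarithmic number of dependent steps.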

About a year ago, I shared my project on Reddit (I've since renamed my account), with good results on the toy dataset shakespeare-char. One commenter asked what steps I had taken to bound the matrix states, and another found that training was inherently unstable on more comprehensive datasets.

Addressing Stability

I addressed both points by experimenting with different methods of creating the input state matrix. Originally, I simply reshaped the input vector into a matrix and added the identity. Since then, I've implemented the following methods:

  • Using the elements of the vector to fill a skew-symmetric matrix, then applying the matrix exponential or the Cayley map to generate an orthogonal matrix
  • Filling LDU factors with elements from the vector and applying an activation function to D to enforce a determinant of 1
  • Building a QR product, using the matrix exponential or Cayley map to create the orthogonal matrix Q and filling the upper-triangular matrix R
  • Dividing the input state by a determinant-correcting scalar factor computed from its determinant
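As an illustration of the first method, here is a rough sketch of the skew-symmetric plus Cayley map construction in plain Python. It is not the repo's code: the helper names are hypothetical, and the convention Q = (I - A)^{-1} (I + A) is an assumption (the Cayley map is also often written with the factors or signs swapped).

```python
from typing import List

Matrix = List[List[float]]


def skew_from_vector(params: List[float], n: int) -> Matrix:
    """Fill the strict upper triangle with params and mirror it with a sign flip."""
    a = [[0.0] * n for _ in range(n)]
    idx = 0
    for i in range(n):
        for j in range(i + 1, n):
            a[i][j] = params[idx]
            a[j][i] = -params[idx]
            idx += 1
    return a


def solve(lhs: Matrix, rhs: Matrix) -> Matrix:
    """Solve lhs @ x = rhs by Gauss-Jordan elimination with partial pivoting."""
    n = len(lhs)
    aug = [list(lhs[i]) + list(rhs[i]) for i in range(n)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(aug[r][col]))
        aug[col], aug[pivot] = aug[pivot], aug[col]
        scale = aug[col][col]
        aug[col] = [v / scale for v in aug[col]]
        for r in range(n):
            if r != col:
                factor = aug[r][col]
                aug[r] = [v - factor * p for v, p in zip(aug[r], aug[col])]
    return [row[n:] for row in aug]


def cayley(a: Matrix) -> Matrix:
    """Assumed convention: Q = (I - A)^{-1} (I + A), orthogonal for skew-symmetric A."""
    n = len(a)
    i_minus = [[(1.0 if i == j else 0.0) - a[i][j] for j in range(n)] for i in range(n)]
    i_plus = [[(1.0 if i == j else 0.0) + a[i][j] for j in range(n)] for i in range(n)]
    return solve(i_minus, i_plus)


# An n x n skew-symmetric matrix has n * (n - 1) / 2 free entries: 3 for n = 3.
q = cayley(skew_from_vector([0.3, -1.2, 0.7], 3))
```

For a real skew-symmetric A, I - A is always invertible, so the solve step cannot hit a zero pivot. The columns of q are orthonormal up to floating-point error.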

I found that these fixes prevented loss spikes, with varying tradeoffs. Interestingly, the scalar factor method led to worse results. Dividing the input states by scalars should only rescale the output states, so the drop suggests that the unscaled model was "cheating" on the toy dataset by learning a simple scalar decay pattern instead of more complex relationships.
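The scaling argument can be checked numerically. This is a small sketch under stated assumptions, not the repo's code: it uses 2x2 matrices, and it assumes the correcting factor is c = sqrt(|det|), so that the normalized matrix has determinant +1 or -1 (the exact factor used in the experiments may differ). It would divide by zero on a singular matrix.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    n = len(a)
    return [[sum(a[i][k] * b[k][j] for k in range(n)) for j in range(n)]
            for i in range(n)]


def det2(m: Matrix) -> float:
    """Determinant of a 2x2 matrix."""
    return m[0][0] * m[1][1] - m[0][1] * m[1][0]


def normalize_det(m: Matrix) -> Tuple[Matrix, float]:
    """Divide a 2x2 matrix by c = sqrt(|det|) so that |det(m / c)| = 1."""
    c = math.sqrt(abs(det2(m)))
    return [[v / c for v in row] for row in m], c


def product(states: List[Matrix]) -> Matrix:
    """Final output state: the ordered product of all input states."""
    out = states[0]
    for m in states[1:]:
        out = matmul(out, m)
    return out


states = [
    [[2.0, 1.0], [0.5, 1.5]],
    [[1.0, -0.5], [0.25, 2.0]],
    [[3.0, 0.0], [1.0, 0.5]],
]
normalized, factors = zip(*(normalize_det(m) for m in states))
raw_out = product(states)
norm_out = product(list(normalized))
total_scale = math.prod(factors)
# norm_out matches raw_out / total_scale entry by entry: normalizing the
# inputs changes only the overall scale of the output state.
```

Because scalars commute with matrices, the per-step factors collect into a single scalar in front of the product, which is why the normalized model loses exactly the scalar decay information.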

Surprisingly, using the Cayley map or the matrix exponential to force the input states to be orthogonal mostly prevented the model from learning information about the sequence, with results closer to the FFN than to the Cayley QR method. The poor performance of orthogonal matrices suggests that the ability to learn shear transformations might be critical for the model. Possibly, rotations enforce dependence on the previous state, whereas shearing allows the model to adjust the state more independently of the previous state.

Experimental Results

Above are the training and validation losses on the shakespeare-char dataset for a small MRU LM, a transformer, and an FFN, with the embedding, state, key, and value sizes set to 256. The MRU LM has a single MRU layer and 4 MLPs, the transformer has a single attention layer and 4 MLPs, and the FFN has only the 4 MLPs. I used a single sequence-mixing layer in order to isolate the effect of the MRU.

Scaling to Larger Datasets

Finally, I moved to a larger dataset, trying to replicate https://huggingface.co/roneneldan/TinyStories-33M by training a baseline GPT-2 model and a model with attention replaced by the MRU. I ended up stopping the training runs early, but the loss curves already show fairly clearly that the MRU performs worse on this task. To create the MRU's input state matrices, I used the LDU factor method, since it had performed best.
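For readers unfamiliar with the LDU construction mentioned above, here is a rough sketch in plain Python. It is not the repo's code. The function name and argument layout are hypothetical, and the activation on D is an assumption: the post does not say which activation was used, so this sketch uses exp(x - mean(x)), whose entries multiply to exactly exp(0) = 1.

```python
import math
from typing import List

Matrix = List[List[float]]


def ldu_state(lower: List[float], diag_raw: List[float],
              upper: List[float]) -> Matrix:
    """Build M = L @ D @ U with det(M) = 1.

    lower / upper hold the strict-triangle entries in row-major order;
    L and U have ones on the diagonal, so det(L) = det(U) = 1.
    """
    n = len(diag_raw)
    mean = sum(diag_raw) / n
    # Assumed activation: centered exponential, so the product of d is 1.
    d = [math.exp(x - mean) for x in diag_raw]

    l_mat = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]
    u_mat = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]
    idx = 0
    for i in range(n):
        for j in range(i):
            l_mat[i][j] = lower[idx]
            idx += 1
    idx = 0
    for i in range(n):
        for j in range(i + 1, n):
            u_mat[i][j] = upper[idx]
            idx += 1

    # M[i][j] = sum_k L[i][k] * d[k] * U[k][j]
    return [[sum(l_mat[i][k] * d[k] * u_mat[k][j] for k in range(n))
             for j in range(n)] for i in range(n)]


# A 3x3 state uses 3 lower, 3 diagonal, and 3 upper parameters.
m = ldu_state([0.5, -1.0, 2.0], [0.3, -0.7, 1.1], [1.5, 0.2, -0.4])
```

Unlike the orthogonal constructions, the unit-triangular factors here can represent shears, while the determinant constraint still keeps the volume scaling of each state fixed at 1.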

Above is the validation loss for a transformer and an LM using the MRU, both with the same hyperparameters and dimensions as the Hugging Face model card. The official TinyStories model was trained for 20 epochs, which corresponds to about 200k steps. To compare the MRU with other linear-time models, I also briefly trained a linear transformer, using the algorithm described in Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention.
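For context, the causal linear attention from that paper can be sketched as a recurrence over two running sums. This is a single-head, plain-Python illustration, not the training code used for the runs above. The function names are hypothetical, and the feature map elu(x) + 1 is the one the paper proposes.

```python
import math
from typing import List

Vector = List[float]


def feature_map(x: Vector) -> Vector:
    """phi(x) = elu(x) + 1, which is x + 1 for x > 0 and exp(x) otherwise."""
    return [v + 1.0 if v > 0 else math.exp(v) for v in x]


def causal_linear_attention(queries: List[Vector], keys: List[Vector],
                            values: List[Vector]) -> List[Vector]:
    """out_t = phi(q_t)^T S_t / (phi(q_t)^T z_t), with running sums S_t and z_t."""
    d_k, d_v = len(keys[0]), len(values[0])
    s = [[0.0] * d_v for _ in range(d_k)]  # S_t = sum over j <= t of phi(k_j) v_j^T
    z = [0.0] * d_k                        # z_t = sum over j <= t of phi(k_j)
    outputs = []
    for q, k, v in zip(queries, keys, values):
        phi_q, phi_k = feature_map(q), feature_map(k)
        for i in range(d_k):
            z[i] += phi_k[i]
            for j in range(d_v):
                s[i][j] += phi_k[i] * v[j]
        denom = sum(phi_q[i] * z[i] for i in range(d_k))
        outputs.append([sum(phi_q[i] * s[i][j] for i in range(d_k)) / denom
                        for j in range(d_v)])
    return outputs


queries = [[0.1, -0.2], [0.5, 0.3], [-1.0, 0.4]]
keys = [[0.2, 0.1], [-0.3, 0.6], [0.7, -0.5]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
outputs = causal_linear_attention(queries, keys, values)
```

The state here is an additive sum of outer products, whereas the MRU state is a product of matrices, which is the main structural difference between the two linear-time models.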

Conclusions and Future Work

I think these results show that the MRU likely doesn't work as a direct replacement for attention in generative language modeling, but I've already laid the groundwork for this algorithm. The MRU has dramatically different strengths and weaknesses compared to other algorithms such as attention, state space models, traditional RNNs, and fast weight programmers. It performs significantly more cumulative computation along the sequence (as opposed to the computation for each token being independent) and is significantly more lightweight and hence faster, but it also has a much lower storage capacity.

I believe that the MRU's alternative uses should still be explored. One use of the MRU could be applying it to the query and key vectors of attention. Similar to RoPE, it would rotate chunks of the vectors, but it would be able to rotate chunks in more than two dimensions and with dynamic, non-commutative angles. This is one of many applications of the algorithm that I will continue to research, and I hope that others are interested in its applications as well.

If you're interested, reach out to me at mikayahlevi@gmail.com, Reddit, GitHub, or any other platform you can find me at.
