Hacker Remix

Multi-Token Attention

145 points by fzliu 1 day ago | 39 comments

kouteiheika 24 hours ago

This is another potential improvement to the transformer architecture from Facebook (the other one that comes to mind is this one from the same authors: https://arxiv.org/abs/2405.18719), but note that it comes with a major problem that might not be obvious at first glance: it's just not usable in practice without a ton of work. It modifies the innards of the attention mechanism, so it is incompatible with Flash Attention (or any other optimized attention library), and you do not want to train anything beyond toy models without Flash Attention (the performance hit is just way too big).

There's PyTorch's FlexAttention, which could maybe make this practical, but currently it's just way too buggy.
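
For reference, a minimal sketch of the hook FlexAttention exposes, assuming a recent PyTorch (2.5 or later); the score_mod here is just an illustrative distance penalty, not Multi-Token Attention itself:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, S, D = 2, 8, 1024, 64
q, k, v = (torch.randn(B, H, S, D) for _ in range(3))

def distance_penalty(score, b, h, q_idx, kv_idx):
    # score_mod rewrites each attention logit elementwise; this particular
    # modification just penalizes distant tokens (ALiBi-style) as a placeholder.
    return score - 0.1 * (q_idx - kv_idx).abs()

# Wrap in torch.compile for fused kernels; uncompiled it falls back to a slow path.
out = flex_attention(q, k, v, score_mod=distance_penalty)
```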

jszymborski 23 hours ago

People familiar with exotic RNNs and improvements to LSTMs know this problem all too well. The moment your LSTM isn't a bog-standard LSTM, it loses all the speed-ups from cuDNN and becomes borderline unusable for anything but toy models.

tpurves 12 hours ago

These would be inherently temporary problems though, right? If it eventually became clear that alternate methods were the way forward, NVIDIA would be highly motivated to do the optimization work, wouldn't they? Any new step functions that can forestall the asymptotic plateauing of AI progress are things they desperately need.

jszymborski 9 hours ago

That follows reason, but in practice I find that it's often not the case. My suspicion is that it's hard to establish that your method is superior to another if, for example, it takes 10-100x the compute to train a model. This is in large part due to the fact that machine learning is currently a deeply empirical field.

Nvidia isn't likely to start releasing updated firmware for an obscure architecture for which there is limited evidence of improvement, and even less adoption.

kouteiheika 2 hours ago

Indeed. Especially when a lot of papers just use cherry-picked results that show some improvement so they can publish something, but their method doesn't work that well when it comes in contact with reality (e.g. see the deluge of papers claiming to have come up with an optimizer better than AdamW), and when the majority of people don't even properly benchmark their new methods with respect to the time overhead. No, it doesn't matter if your method achieves 1% better loss if it takes 10% longer to train, because if I'd trained for 10% longer without your method I'd get an even better loss; and don't even get me started on people not tuning their baselines.

I've been burnt way too many times by fancy new methods that claimed improvement, where I spent a ton of effort to implement them, and they ended up being poop.

Every person working in the field and pushing papers should read this blog post and apply what's written in it: https://kellerjordan.github.io/posts/muon/#discussion-solvin...

ssivark 8 hours ago

Check out "The Hardware Lottery" [1], which drove a lot of discussion a few years ago.

[1]: https://arxiv.org/abs/2009.06489

albertzeyer 17 hours ago

Why do you say FlexAttention is too buggy? I have heard about a lot of successful uses of it, and never heard about any such problems.

Also note, depending on your model dimensions and sequence lengths, often the attention computation plays only a minor role (maybe 10% overall or so), and the MLP computation dominates.

kouteiheika 13 hours ago

Last time I tried it I encountered both showstopper bugs (it was obviously, completely broken) and subtle correctness bugs (it looked like it was working, but since I'm paranoid I have unit tests for everything, and numerically the errors were too big compared to what you'd get with eager attention or Flash Attention). It was also too slow for my taste compared to Flash Attention, so I just dropped it. And I wasn't even doing anything super exotic with it.

Maybe it's better now, but I'd still consider using FlexAttention without a corresponding unit test checking its accuracy against an equivalent eager implementation completely irresponsible.
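
Something along these lines is enough (a rough sketch; the tolerances are placeholders you'd tune to your dtype and shapes):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention

def eager_attention(q, k, v):
    # Plain softmax attention as the reference implementation.
    scale = q.shape[-1] ** -0.5
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    return torch.matmul(F.softmax(scores, dim=-1), v)

def test_flex_matches_eager():
    torch.manual_seed(0)
    q, k, v = (torch.randn(1, 4, 128, 64) for _ in range(3))
    out = flex_attention(q, k, v)   # no score_mod: should equal plain attention
    ref = eager_attention(q, k, v)
    torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)

if __name__ == "__main__":
    test_flex_matches_eager()
```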

bionhoward 1 day ago

How does this compare with the Byte Latent Transformer [1]? This one works via convolution post-embedding, while BLT applies attention at embedding time?

1. https://ai.meta.com/research/publications/byte-latent-transf...

janalsncm 1 day ago

As I understand it, BLT uses a small nn to tokenize but doesn’t change the attention mechanism. MTA uses traditional BPE for tokenization but changes the attention mechanism. You could use both (latency be damned!)

bigdict 1 day ago

Sure, you can get better model performance by throwing more compute at the problem in different places. But does it improve performance on an isoFLOP basis?

Reubend 1 day ago

It's a valid criticism that this method would increase compute requirements, but sometimes an improvement in the end result justifies the compute needed. For things like code generation in large datasets, many people would be willing to "pay" with more compute if the results were better. And this doesn't seem to require more memory bandwidth, so it could be particularly good for local models.

fabmilo 1 day ago

I read the paper and the results don't really convince me that is the case. But the problem still remains of being able to use information from different parts of the model without squishing it to a single value with the softmax.

eightysixfour 1 day ago

That's... not always a given for SOTA-sized models. When the ROI on more training runs out, it is nice to have alternatives, whether that is RL-tuned reasoning models or alternative architectures that improve specific areas of weakness.

jwilber 1 day ago

There’s no one-size-fits-all answer here, but in my experience, for long contexts, conv-based methods outperform strictly attention-based methods. See Evo2:

“With the current implementation of Evo2, we do not have the heavily optimized kernels in place for convolution operators like we do for attention layers in a model like llama2. Even with this shortcoming, we see that the benefit from including more convolutional layers makes up for the earlier stage of optimization at around the 64k context length. Beyond that point we see an improvement in performance even compared to a highly optimized transformer model.“

https://docs.nvidia.com/bionemo-framework/latest/models/evo2...

bob1029 1 day ago

So, we're proposing a multiplicative increase of something that already scales quadratically with the context size?

I think we've already got a bit of a bottleneck in terms of memory bandwidth utilization.

kadushka 1 day ago

If you have a bottleneck in terms of memory bandwidth utilization, this method is great - it would utilize the idle compute.

EGreg 1 day ago

LLaMA 3 already uses RoPE, which can handle arbitrarily long contexts (within reason):

https://arxiv.org/abs/2104.09864

The difference RoPE makes vs. traditional positional encoding is that you only care about relative distances between tokens, and you can attenuate the attention over great distances.
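
A minimal sketch of the rotation, with shapes simplified for illustration; the point is that the dot product of the rotated q and k depends only on their relative offset:

```python
import torch

def rope(x, base=10000.0):
    # x: (seq_len, dim) with dim even; rotate each (even, odd) feature pair
    # by an angle proportional to the token's absolute position.
    seq_len, dim = x.shape
    pos = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)          # (seq, 1)
    freqs = base ** (-torch.arange(0, dim, 2, dtype=torch.float32) / dim)  # (dim/2,)
    angles = pos * freqs                                                   # (seq, dim/2)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

# Only the relative offset matters: shifting both positions by the same
# amount leaves the attention score (the dot product) unchanged.
q, k = torch.randn(64), torch.randn(64)

def score(pos_q, pos_k):
    seq = torch.zeros(max(pos_q, pos_k) + 1, 64)
    seq[pos_q] = q
    seq[pos_k] = k
    rotated = rope(seq)
    return (rotated[pos_q] @ rotated[pos_k]).item()

print(score(0, 3), score(10, 13))  # approximately equal
```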

Instead of making the model look at every token in the entire sequence all at once (which gets expensive fast), you can break the text into logical chunks, like sentences or paragraphs, and run self-attention within each chunk. That keeps things efficient while still capturing local meaning.

Then, for each chunk, you create a summary, either by pooling or using a small learned head, and pass those summaries into a second layer of attention that operates on a much smaller scale. This gives you higher-level context across the document, kind of like moving from sentences to sections to the whole thing. Optionally, you can even send that higher-level context back down to influence the lower layers.

This approach shows up in models like Longformer and BigBird (which use attention windows), hierarchical models (like HANs), and newer architectures like RetNet and Mamba that compress information over time or scale. RoPE fits neatly into this by helping each chunk handle relative positions more naturally.
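
A rough sketch of that two-level idea (not any particular published architecture; the chunk size, mean-pooling summary, and broadcast-add are arbitrary choices here):

```python
import torch
import torch.nn as nn

class ChunkedHierarchicalAttention(nn.Module):
    """Local self-attention within fixed-size chunks, then attention over
    per-chunk summaries, with the chunk-level context added back down."""

    def __init__(self, dim: int, n_heads: int, chunk: int):
        super().__init__()
        self.chunk = chunk
        self.local_attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.global_attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq, dim), with seq divisible by the chunk size.
        b, s, d = x.shape
        chunks = x.reshape(b * s // self.chunk, self.chunk, d)
        local, _ = self.local_attn(chunks, chunks, chunks)         # attention inside each chunk
        summaries = local.mean(dim=1).reshape(b, -1, d)            # one pooled token per chunk
        context, _ = self.global_attn(summaries, summaries, summaries)  # attention across chunks
        # Send the higher-level context back down to the token level.
        return local.reshape(b, s, d) + context.repeat_interleave(self.chunk, dim=1)

x = torch.randn(2, 256, 64)
print(ChunkedHierarchicalAttention(dim=64, n_heads=8, chunk=32)(x).shape)  # (2, 256, 64)
```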

RoPE is kind of perfect for this setup because it handles relative positions directly in the attention mechanism, which means each chunk can still understand the order and spacing of tokens without relying on fixed position embeddings. It’s especially useful when you're working with long sequences or chunked inputs, because it doesn’t care where the chunk is in the overall document—it just cares about how tokens relate to each other within that chunk. RoPE also makes it easier for models to generalize to longer inputs than they were trained on, since the rotational math behind it naturally extends beyond the original context window. Plus, because it's baked into the dot product itself, it adds no extra memory or computation, and plays well with hierarchical or multi-scale attention setups. Basically, it’s a clean, efficient way to inject positional awareness that doesn’t break when you start slicing things up.

PS: LLaMA's RoPE may be a bit off but it still works great: https://discuss.huggingface.co/t/is-llama-rotary-embedding-i...

cma 1 day ago

> allowing nearby queries and keys to affect each other's attention weights for more precise attention

If it is only nearby tokens, it is multiplicative by a constant, right? Not making it cubic in context length or anything.

DeepSeek got a training performance increase from predicting two tokens at a time, though it doesn't carry into the final model's inference like this does. They did say it can be used for speculative decoding to reduce inference costs, though.

They may get away with fewer attention heads with this new approach, too.

jgalt212 1 day ago

Maybe Sam was right about needing one trillion dollars!