Authors: Korbinian Pöppel, Maximilian Beck, Sepp Hochreiter
Abstract: While Transformers and other sequence-parallelizable neural network
architectures appear to define the current state of the art in sequence
modeling, they specifically lack state-tracking capabilities, which are
important for time-series tasks and logical reasoning. Traditional RNNs like
LSTMs and GRUs, as well as modern variants like sLSTM, do have these
capabilities, at the cost of strictly sequential processing. While this is
often seen as a strong limitation, we show how fast these networks can get
with FlashRNN, our hardware-optimized implementation in Triton and CUDA, with
kernels tuned down to the register level of modern GPUs. We extend traditional
RNNs with a parallelization variant that processes multiple RNNs with smaller
hidden states in parallel, similar to the head-wise processing in Transformers
(sketched below).
To enable flexibility across different GPU variants, we introduce a new
optimization framework for hardware-internal cache sizes, memory, and compute
handling. It models the hardware using polyhedral-like constraints, including
the notion of divisibility, which speeds up the solution process in our
ConstrINT library for general integer constraint satisfaction problems
(integer CSPs); a toy example follows below. We show that our kernels can
achieve 50x speed-ups over a vanilla PyTorch implementation, and that our CUDA
kernels allow 40x larger hidden sizes than our Triton implementation. Our
open-source kernels and the optimization library are released to boost
research in the direction of state-tracking-enabled RNNs and sequence
modeling: \url{https://github.com/NX-AI/flashrnn}
Source: http://arxiv.org/abs/2412.07752v1
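
The two technical ideas above can be illustrated with short, self-contained
sketches. First, a minimal PyTorch sketch of the head-wise parallelization
idea; this is illustrative code, not the FlashRNN API, and all names and sizes
are made up. Instead of one RNN with hidden size D, H independent RNNs of size
D/H run in parallel, turning the recurrent matmul into a batched
block-diagonal one, while the time loop stays strictly sequential (which is
where the state-tracking ability lives).

    # A minimal sketch (not the FlashRNN API) of head-wise RNN parallelization.
    import torch

    def headwise_rnn_step(h, x, R, b):
        """One Elman-style recurrent step with H parallel heads.

        h: (B, H, Dh)  previous hidden states per head
        x: (B, H, Dh)  pre-projected input per head
        R: (H, Dh, Dh) per-head recurrent matrices (block-diagonal overall)
        b: (H, Dh)     per-head biases
        """
        # einsum batches the H small matmuls; heads never mix state.
        return torch.tanh(x + torch.einsum("hij,bhj->bhi", R, h) + b)

    B, H, Dh, T = 2, 4, 16, 8          # batch, heads, per-head hidden, time
    R = torch.randn(H, Dh, Dh) / Dh**0.5
    b = torch.zeros(H, Dh)
    h = torch.zeros(B, H, Dh)
    for t in range(T):                  # the time loop stays strictly sequential
        x_t = torch.randn(B, H, Dh)
        h = headwise_rnn_step(h, x_t, R, b)

Because the heads never mix state, each small recurrent matrix exhibits the
kind of locality that register-level kernels can exploit.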
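
Second, a toy integer CSP of the kind the optimization framework addresses:
choosing kernel tiling sizes under divisibility and memory-budget constraints.
This brute-force sketch is not the ConstrINT API, and all constants (hidden
size, memory budget, tile ranges) are hypothetical.

    # A toy integer CSP (not the ConstrINT API): pick tile sizes that divide
    # the hidden size and respect a shared-memory budget. Numbers are made up.
    from itertools import product

    HIDDEN = 768              # hypothetical hidden size
    SMEM_BUDGET = 48 * 1024   # hypothetical shared-memory budget in bytes
    BYTES = 4                 # fp32

    def feasible_tilings():
        for tile_h, tile_b in product(range(1, 257), range(1, 65)):
            if HIDDEN % tile_h != 0:             # divisibility constraint
                continue
            # weights tile + states tile, in bytes
            smem = BYTES * (tile_h * tile_h + tile_h * tile_b)
            if smem > SMEM_BUDGET:               # polyhedral-like (linear) constraint
                continue
            yield tile_h, tile_b, smem

    # e.g. pick the largest feasible tile as a crude utilization proxy
    best = max(feasible_tilings(), key=lambda s: s[0] * s[1])
    print(best)

A real solver enumerates far larger spaces than this brute-force loop, which
is why modeling divisibility explicitly, as ConstrINT does, pays off.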