Text Generation with TED - Trainable Exponential Decay(s)

Abstract

Unlike Transformers, which compare every token with every other token, or RNNs, which process tokens strictly in sequence, Trainable Exponential Decay(s) (TED) does not directly depend on past tokens. The core mechanism considers each token only once, in isolation, to determine its decay rate λ (lambda). Past tokens are then subjected to exponential decay, and their decayed influences are summed and added to the current token to create rich representations. TED benefits from parallel hardware during training, because the only dependency on past tokens is through addition, which can be performed at any time and in any order within a single layer. After training, one can observe that the influence of past tokens naturally decays over time, so some tokens become irrelevant and can be dropped from the context. Moreover, the well-established mathematics of exponential decay makes it possible to predict precisely when a particular token will become 'dead'. This opens up numerous optimization opportunities during inference. The simplest strategy keeps only the K most relevant tokens at any given time, which ensures constant memory requirements. Alternatively, a dynamic approach drops tokens whose relevance scores fall below a certain threshold. The relevance score is calculated by multiplying a token's vector magnitude ‖v‖ by its current decay factor at time t. Because each token's contribution depends only on its own decay rate and age, the order in which the retained tokens are stored does not matter. This eliminates the need for a positional encoding mechanism, which is seen as problematic in Transformers. The architecture of TED follows that of a Transformer, except that the attention module is replaced by an exponential decay module. TED can be implemented in PyTorch without special CUDA kernels. The code is available at: https://github.com/biuq/ted

Tokens decay over time and their influence is summed creating more useful representations.

Tiny Stories TED Model - Demo

This is a proof of concept intended to show that TED can generate plausible text. The model was trained on a GTX 1060 for 20,000 iterations.

It runs in your browser as plain, unoptimized JavaScript without any parallelization.

Download size: 25 MiB

Click Load below and wait for it to finish to try the model.

Initial prompt:

Sampling config:

Context config (all layers):


The table below shows the number of tokens currently held in the context of each layer. Higher layers contain more useful representations, so the model learns not to forget them as quickly: it predicts a lower decay rate for them, and higher layers therefore retain more tokens. The effect is visible when Top-K in the Context config is set to 0 (the Top-K strategy is disabled) and the relevance score threshold is greater than 0 (tokens whose relevance score falls below the threshold are dropped).

Layer Index | Context Size | Avg. Context Size


For the first layer (layer 0), the lifetime of each character can be calculated ahead of time: character embeddings stay constant once learned, so the lambda (λ) parameter is also constant per character. The calculation depends on the relevance score threshold set above in the Context config. The "lifetime" is the number of discrete steps into the future until the token is dropped from the context.

Format: [character] : [lifetime]
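The lifetime follows directly from the relevance score: a token with magnitude ‖v‖ and decay rate λ is dropped at the first step t where ‖v‖·exp(-λt) falls below the threshold, i.e. the first integer t greater than ln(‖v‖ / threshold) / λ. Below is a minimal sketch of that arithmetic; the function name and arguments are hypothetical and not taken from the TED repository.

```python
import math
from typing import Optional


def token_lifetime(magnitude: float, decay_rate: float, threshold: float) -> Optional[int]:
    """Steps until magnitude * exp(-decay_rate * t) first falls below threshold.

    Returns None when the token never drops (no threshold or no decay).
    Hypothetical helper for illustration only.
    """
    if threshold <= 0.0 or decay_rate <= 0.0:
        return None
    if magnitude < threshold:
        return 0  # already below the threshold when it enters the context
    # Smallest integer t strictly greater than ln(magnitude / threshold) / decay_rate.
    return math.floor(math.log(magnitude / threshold) / decay_rate) + 1


if __name__ == "__main__":
    # A token that halves in relevance every step (lambda = ln 2).
    print(token_lifetime(magnitude=1.0, decay_rate=math.log(2), threshold=0.1))
```

With a decay rate of ln 2 the relevance halves every step, so a unit-magnitude token goes 1, 0.5, 0.25, 0.125, 0.0625 and first falls below a threshold of 0.1 at step 4.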

Architecture

The TED architecture is similar to LLaMA [3] in some respects.

However, there is no need for positional encodings, and the attention layer is replaced with an exponential decay module.

Multiple TED Layers can be stacked together.

Exponential Decay Module

The Exponential Decay Module is built on the premise that not all tokens carry equal weight as time progresses. By computing a decay factor that exponentially diminishes the influence of older tokens, the model can maintain a compact and relevant context, enabling more efficient processing and potentially enhancing performance on tasks that require a nuanced understanding of temporal dynamics. The core idea behind TED is to learn and subsequently predict the appropriate value for the decay constant in the exponential decay equation.

The model learns to predict λ (lambda).
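For reference, the standard exponential decay law is written out below. The symbols N(t) and N_0 are the usual textbook notation and are an assumption here, not taken from the source; the "quantity" computed later in the recipe plays the role of N_0, and t counts discrete token steps.

```latex
N(t) = N_0 \, e^{-\lambda t}, \qquad \lambda > 0, \qquad t_{1/2} = \frac{\ln 2}{\lambda}
```

A larger λ means faster forgetting: the half-life t_{1/2} is the number of steps after which a token's influence has dropped to half of its initial value.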

The recipe

1. Start with input tokens.

tokens

2. Use a trainable weight matrix to predict the negative of lambda (-λ) from the input. The log of a sigmoid is always negative, so λ stays positive.

negative_lambdas = self.lambda_matrix(tokens).sigmoid().log()

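3. Construct the time matrix, where n is the sequence length. Entry (i, j) equals i - j below the diagonal (the number of steps since token j) and 0 elsewhere.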

time = torch.ones(n, n).tril_(-1).cumsum_(-2)

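4. Calculate the decay factor. Entry (i, j) is exp(-λ_j · (i - j)) for i ≥ j and 0 otherwise, so each past token j is scaled according to its own decay rate and its age.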

decay_factor = (negative_lambdas.swapaxes(-1, -2) * time).exp().tril()

5. Use a trainable matrix to calculate the quantity that each token contributes.

quantities = self.quantity_matrix(tokens)

6. Calculate the influence (output) at every time step.

output = decay_factor @ quantities

7. Use trainable matrices to transform and gate the output as the last step.

y = self.output_matrix(F.silu(output)) * self.gate_matrix(tokens)

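8. Putting it all together, the transformation performed by the Exponential Decay Module from input tokens X to output Y can be expressed as follows. Here W_λ, W_q, W_o and W_g denote the lambda, quantity, output and gate matrices from the steps above (the notation is reconstructed from those steps, and bias terms, if any, are omitted):

$$
\Lambda = \log \sigma(X W_\lambda), \qquad
D_{ij} = \begin{cases} \exp\big(\Lambda_j \, (i - j)\big) & i \ge j \\ 0 & i < j \end{cases}, \qquad
Y = \big(\operatorname{SiLU}(D \, X W_q) \, W_o\big) \odot \big(X W_g\big)
$$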
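Steps 1 to 7 can be assembled into a single PyTorch module, as in the sketch below. The attribute names follow the snippets in the recipe; the class name, the layer sizes, the choice of one scalar λ per token, and `bias=False` are assumptions made for illustration and may differ from the code in the repository.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ExponentialDecayModule(nn.Module):
    """Sketch of recipe steps 1-7. Sizes and bias=False are assumptions."""

    def __init__(self, dim: int):
        super().__init__()
        # Assumption: a single scalar decay rate per token.
        self.lambda_matrix = nn.Linear(dim, 1, bias=False)
        self.quantity_matrix = nn.Linear(dim, dim, bias=False)
        self.output_matrix = nn.Linear(dim, dim, bias=False)
        self.gate_matrix = nn.Linear(dim, dim, bias=False)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # tokens: (batch, n, dim)
        n = tokens.shape[-2]
        # Step 2: -lambda per token, always negative. Shape (batch, n, 1).
        negative_lambdas = self.lambda_matrix(tokens).sigmoid().log()
        # Step 3: time[i, j] = i - j below the diagonal, 0 elsewhere.
        time = torch.ones(n, n, device=tokens.device, dtype=tokens.dtype).tril_(-1).cumsum_(-2)
        # Step 4: decay_factor[b, i, j] = exp(-lambda_j * (i - j)) for i >= j, else 0.
        decay_factor = (negative_lambdas.swapaxes(-1, -2) * time).exp().tril()
        # Step 5: the quantity each token contributes.
        quantities = self.quantity_matrix(tokens)
        # Step 6: sum of decayed quantities at every time step. Shape (batch, n, dim).
        output = decay_factor @ quantities
        # Step 7: transform and gate.
        return self.output_matrix(F.silu(output)) * self.gate_matrix(tokens)


if __name__ == "__main__":
    module = ExponentialDecayModule(dim=8)
    y = module(torch.randn(2, 5, 8))
    print(y.shape)
```

Because the decay factor is lower triangular, position i only sees tokens j ≤ i, so the module is causal without an explicit attention mask.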

Inference

During inference, the model must efficiently manage the context to maintain performance while conserving computational resources. Past tokens are kept in the context and are subject to an exponential decay, which reduces their relevance over time. This decay simulates the natural process of fading memory, where older information becomes less significant. To quantify the fading influence of each token, a relevance score is calculated. This score is derived by multiplying the magnitude of the token's vector representation by an exponential decay factor.

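The relevance score of a token with vector representation v and decay rate λ, t steps after it entered the context:

$$
r(t) = \lVert v \rVert \, e^{-\lambda t}
$$

With the relevance score established, one can employ various strategies to manage the context by selectively retaining only those tokens that are deemed crucial for the model's current state. These strategies include:

1. Top-K: keep only the K most relevant tokens at any given time, which keeps memory requirements constant.
2. Threshold: drop tokens whose relevance score falls below a chosen threshold, so the context size varies dynamically.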
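The two context-management strategies named in this document (Top-K and a relevance threshold) can be sketched in plain Python as follows. All names here (`ContextEntry`, `step_context`, `layer_output`) are hypothetical, and the convention that `top_k=0` disables Top-K mirrors the demo's Context config; this is an illustration under those assumptions, not the repository's implementation.

```python
import math
from dataclasses import dataclass
from typing import List


@dataclass
class ContextEntry:
    quantity: List[float]  # the token's vector v
    decay_rate: float      # its predicted lambda (> 0)
    age: int = 0           # steps since the token entered the context

    def decay(self) -> float:
        return math.exp(-self.decay_rate * self.age)

    def relevance(self) -> float:
        # Relevance score: ||v|| * exp(-lambda * t)
        magnitude = math.sqrt(sum(x * x for x in self.quantity))
        return magnitude * self.decay()


def step_context(context: List[ContextEntry], new_entry: ContextEntry,
                 top_k: int = 0, threshold: float = 0.0) -> List[ContextEntry]:
    """Age the stored tokens, add the new one, then prune the context."""
    for entry in context:
        entry.age += 1
    context = context + [new_entry]
    if threshold > 0.0:  # dynamic strategy: drop tokens below the threshold
        context = [e for e in context if e.relevance() >= threshold]
    if top_k > 0:        # constant-memory strategy: keep the K most relevant
        context = sorted(context, key=lambda e: e.relevance(), reverse=True)[:top_k]
    return context


def layer_output(context: List[ContextEntry]) -> List[float]:
    """Sum of decayed quantities; the order of the entries does not matter."""
    if not context:
        return []
    out = [0.0] * len(context[0].quantity)
    for entry in context:
        weight = entry.decay()
        for k, value in enumerate(entry.quantity):
            out[k] += weight * value
    return out
```

Since `layer_output` is a plain sum, reordering the surviving entries (as the Top-K sort does) leaves the result unchanged, which is why no positional information needs to be stored alongside them.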

References

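[1]: Eldan, R. and Li, Y. "TinyStories: How Small Can Language Models Be and Still Speak Coherent English?" arXiv:2305.07759

[2]: huggingface.co/datasets/noanabeshima/TinyStoriesV2

[3]: Touvron, H. et al. "LLaMA: Open and Efficient Foundation Language Models." arXiv:2302.13971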