Move the n-gram extraction into your Keras model!

2019-08-07

This post shows how to improve the accuracy of character-level text classification with the deep learning framework Keras, using n-gram extraction inside the model.

Originally I published this post on the codecentric blog.

Our motivation

In a project on large-scale text classification, a colleague of mine significantly raised the accuracy of our Keras model by feeding it with bigrams and trigrams instead of single characters. For his experiments he could modify the preprocessing and the model as he wished, but for production it was much preferable to just replace the model being served by TensorFlow and leave all other code unchanged. And that is what we did — move the bigram and trigram extraction into our neural network. In this blog post, I'll show you the idea, the implementation, an experiment, and the limitations of our approach.

The idea — n-gram extraction via convolution

Suppose we want to process the quote

"I'd far rather be happy than right any day"

by Douglas Adams. Instead of looking at the text as a sequence of characters

I'd far rather

a neural network may profit from looking at pairs of adjacent characters, that is, at the sequence of bigrams (writing _ for the space character)

I'  'd  d_  _f  fa  ar  r_  _r  ra  at  th  he  er

or even trigrams or n-grams for n larger than 3. To feed the neural network, we need to convert characters into numbers, for example using their ASCII or UTF-8 codes:

73  39  100  32  102  97  114  32  114  97  116  …

Our bigrams then become sequences of pairs of numbers:

(73, 39), (39, 100), (100, 32), (32, 102), (102, 97), …

If we encode these bigrams using the rule

(a, b) ↦ N · a + b,

where N is the size of our alphabet, we obtain a sequence of numbers again: in case N=256, this would be

73·256 + 39 = 18727
39·256 + 100 = 10084
100·256 + 32 = 25632
…

More generally, we can encode n-grams for arbitrary n using the rule

(a₀, …, aₙ₋₁) ↦ Nⁿ⁻¹ · a₀ + Nⁿ⁻² · a₁ + … + N · aₙ₋₂ + aₙ₋₁.

For example, with N = 256 the trigram (73, 39, 100) is encoded as 73·256² + 39·256 + 100 = 4794212.

Here comes the key observation: with this encoding rule,

extracting n-grams becomes a convolution of the sequence of character codes with the kernel (1, N, …, Nⁿ⁻¹).

And this preprocessing step can easily be inserted as a first step into any character-level text-processing neural network.

The implementation

As a warm-up, let us implement the n-gram extraction as a convolution with NumPy. Given a NumPy array of character codes, the n-gram length n and the size of the alphabet N, the following function returns the sequence of encoded n-grams as an array:

import numpy as np

def ngrams_numpy(array, n, alphabet_size):
    # kernel (1, N, ..., N^(n-1)); np.convolve flips it, so each window
    # (a_0, ..., a_(n-1)) is encoded as N^(n-1)·a_0 + ... + a_(n-1)
    kernel = np.power(alphabet_size, range(0, n))
    return np.convolve(array, kernel, mode='valid')
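As a quick sanity check, we can reproduce the first encoded bigrams of the example above:

import numpy as np

codes = np.array([ord(c) for c in "I'd far rather"])
print(ngrams_numpy(codes, 2, 256)[:3])
# -> [18727 10084 25632]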

Next, how about the deep learning library Keras? Suppose we already have a working text-processing model whose inputs are (batches of) sequences of character codes. Then we can add bigram or n-gram extraction as a first layer using a Lambda layer in one line. Indeed, given a batch of samples in the form of a tensor of shape (batch_size, sample_length), the following function returns a batch of encoded bigrams in the form of a tensor of shape (batch_size, sample_length − 1):

from keras import layers

def bigrams_lambda_layer(alphabet_size):
    # encode each adjacent pair (a, b) as N·a + b, matching the rule above
    convolve = lambda x: x[:, :-1] * alphabet_size + x[:, 1:]
    return layers.Lambda(convolve)
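To see it in action, here is a minimal check using the example codes from above:

import numpy as np
from keras import layers, models

inputs = layers.Input(shape=(4,))
outputs = bigrams_lambda_layer(256)(inputs)
model = models.Model(inputs=inputs, outputs=outputs)
print(model.predict(np.array([[73, 39, 100, 32]])))
# -> [[18727. 10084. 25632.]]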

However, Lambda layers in Keras may cause problems when saving, loading or checkpointing the model. For further deployment of a model, for example with TensorFlow Serving, it might be better to avoid the Lambda layer and to use a 1d-convolutional layer with fixed weights as follows:

import numpy as np
from keras import layers, backend

def ngram_block(n, alphabet_size):
    def wrapped(inputs):
        layer = layers.Conv1D(1, n, use_bias=False, trainable=False)
        x = layers.Reshape((-1, 1))(inputs)
        x = layer(x)
        # Conv1D computes a cross-correlation (no kernel flip), so the
        # powers are reversed to get N^(n-1)·a_0 + ... + a_(n-1) as above
        kernel = np.power(alphabet_size, range(n - 1, -1, -1),
                          dtype=backend.floatx())
        layer.set_weights([kernel.reshape(n, 1, 1)])
        return layers.Reshape((-1,))(x)

    return wrapped

This function can be used like a layer,

bigrams_tensor = ngram_block(2, alphabet_size)(input_tensor)

see also the experiment below. What this function does (a quick numeric check follows the list) is

  • create a 1d-convolutional layer, bound to the variable layer, with one feature map, window size n, no bias term and frozen weights that are not changed during training,
  • reshape the input inputs, which is a tensor of shape (batch_size, sample_length), to a tensor x with shape (batch_size, sample_length, 1) (necessary because convolutional layers operate on sequences of vectors and not on sequences of scalars),
  • apply the convolutional layer to the reshaped input,
  • set the kernel of the convolutional layer to the powers (Nⁿ⁻¹, …, N, 1) and
  • reshape the output of the convolutional layer from (batch_size, sample_length − n + 1, 1) to (batch_size, sample_length − n + 1).
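As with the Lambda version, a minimal numeric check on the example codes from above confirms that the block computes the expected bigram codes:

import numpy as np
from keras import layers, models

inputs = layers.Input(shape=(4,))
outputs = ngram_block(2, 256)(inputs)
model = models.Model(inputs=inputs, outputs=outputs)
print(model.predict(np.array([[73, 39, 100, 32]])))
# -> [[18727. 10084. 25632.]]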

An experiment

Let us finally see how this idea works out for a classical test case, the 20 newsgroups dataset, where the task is to guess the topic of a given post from its text. We use a simple character-level convolutional network and see how n-gram extraction inside the model affects the classification accuracy and training time.

To load the data, we use the datasets module of scikit-learn:

from sklearn.datasets import fetch_20newsgroups

data = fetch_20newsgroups(subset="train",
                          remove=("headers", "footers", "quotes"))
posts, topics = data["data"], data["target"]

Now posts is a list of newsgroup posts as strings, and topics is a list of numbers representing the respective newsgroup topics. For each topic, we have 350 to 600 samples:

[Figure: histogram of samples per topic]
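If you want to reproduce these counts without the plot, a minimal sketch (target_names is part of the dataset object returned by scikit-learn):

import numpy as np

for name, count in zip(data["target_names"], np.bincount(topics)):
    print(f"{name:>25}  {count}")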

Note that this is way too little data for a character-level model to perform well. But let us try nevertheless. We apply some minimal preprocessing and

  • convert the characters to lower case,
  • filter out all characters that are not contained in our chosen ALPHABET,
  • replace the remaining characters by their index in the ALPHABET,
  • trim the sequence of indices to a fixed length MAX_LEN (np.resize pads shorter sequences by repeating them),
  • stack all those sequences in one large NumPy array:
import numpy as np

ALPHABET = "abcdefghijklmnopqrstuvwxyz1234567890 !$#()-=+:;,.?/"
MAX_LEN = 1000

def encode_sample(sample, index):
    indices = np.array([index[char] for char in sample
                                             if char in index])
    # np.resize trims long sequences and cyclically repeats short ones
    return np.resize(indices, MAX_LEN)

index = {char: i + 1 for i, char in enumerate(ALPHABET)}
X = np.stack([encode_sample(x.lower(), index) for x in posts])
y = np.eye(20)[topics]  # one-hot encode the topics
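To illustrate the encoding, here is a small example; the expected output assumes the ALPHABET above, where a ↦ 1, b ↦ 2, and so on:

print(encode_sample("hi there!", index)[:12])
# h=8, i=9, space=37, t=20, h=8, e=5, r=18, e=5, !=38,
# then np.resize starts repeating: [ 8  9 37 20  8  5 18  5 38  8  9 37]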

Now X is an array of shape (len(posts), MAX_LEN), and y is an array of shape (len(posts), 20) containing the one-hot encoded topics. As a baseline, we train a simple convolutional model:

import json

from keras import layers, models, optimizers

LAYER_PARAMS = [[64, 3, 3], [128, 3, 3]]
EMBEDDING_DIM = 16

def build_model():
    inputs = layers.Input(shape=(MAX_LEN,))
    # input_dim is len(ALPHABET) + 1: the characters are indexed from 1,
    # and 0 is reserved (e.g. for posts without any valid characters)
    x = layers.Embedding(len(ALPHABET) + 1, EMBEDDING_DIM)(inputs)
    for filters, kernel_size, pool_size in LAYER_PARAMS:
        x = layers.Conv1D(filters, kernel_size, activation="relu")(x)
        x = layers.BatchNormalization()(x)
        x = layers.SpatialDropout1D(0.15)(x)
        x = layers.MaxPooling1D(pool_size)(x)
    x = layers.GlobalAveragePooling1D()(x)
    x = layers.Dense(20, activation="softmax")(x)
    model = models.Model(inputs=inputs, outputs=x)
    model.compile(optimizer=optimizers.Adadelta(),
                  loss="categorical_crossentropy",
                  metrics=["acc"])
    return model

model = build_model()
history = model.fit(X, y, epochs=60, batch_size=20,
                    validation_split=0.2)
with open('baseline.json', 'w') as file:
    json.dump(history.history, file)

The results are quite poor — the validation accuracy reaches just under 60 percent:

[Figure: training history, baseline]

By careful tuning of hyperparameters, things certainly could be improved a bit. Now let us see how bigram and trigram extraction affects the performance of the model. Using the function ngram_block, we only need to insert a line x = ngram_block(n, alphabet_size)(inputs) between the Input and Embedding layers in build_model and enlarge the embedding accordingly:

def build_ngram_model(n):
    inputs = layers.Input(shape=(MAX_LEN,))
    alphabet_size = len(ALPHABET) + 1  # again reserving the index 0
    x = ngram_block(n, alphabet_size)(inputs)
    # one embedding row for every potential n-gram; Keras' Embedding
    # casts the float codes coming out of the Conv1D back to integers
    x = layers.Embedding(pow(alphabet_size, n),
                         n * EMBEDDING_DIM)(x)
    for filters, kernel_size, pool_size in LAYER_PARAMS:
        x = layers.Conv1D(filters, kernel_size, activation="relu")(x)
        x = layers.BatchNormalization()(x)
        x = layers.SpatialDropout1D(0.05 + 0.1 * n)(x)
        x = layers.MaxPooling1D(pool_size)(x)
    x = layers.GlobalAveragePooling1D()(x)
    x = layers.Dense(20, activation="softmax")(x)
    model = models.Model(inputs=inputs, outputs=x)
    model.compile(
        optimizer=optimizers.Adadelta(),
        loss="categorical_crossentropy",
        metrics=["acc"],
    )
    return model

We also raise the embedding dimension (because now we embed bigrams and trigrams instead of single characters) and use a spatial dropout rate that grows with n. Let us see how the n-gram model performs:

histories = {}
for n in range(1, 4):
    history = build_ngram_model(n).fit(X, y, epochs=40,
                                       batch_size=20,
                                       validation_split=0.2)
    histories[n] = history.history  # keep for the comparison below

The training histories show that the n-gram extraction yields a significant improvement:

[Figure: training history, with n-grams]

Indeed, the mean validation accuracy over the last 5 training epochs increased by more than 10 percentage points:

n                          1        2        3
mean validation accuracy   0.5796   0.6401   0.7064
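These numbers can be read off from the histories collected in the training loop above:

import numpy as np

for n, history in histories.items():
    # mean validation accuracy over the last 5 epochs
    print(n, np.mean(history["val_acc"][-5:]))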

Limitations of the technique

Why did we stop at trigrams in the experiment above? The reason is that we do not encode only the n-grams that actually occur in our samples, but reserve codes for all n-grams that could possibly occur. And that makes a huge difference as n grows larger:

n                      1    2      3        4          5
#(occurring n-grams)   52   2,596  47,203   214,362    551,904
#(potential n-grams)   51   2,601  132,651  6,765,201  345,025,251

(That 52 occurring unigrams exceed the 51 potential ones is an artifact of the preprocessing: np.resize turns posts without any ALPHABET characters into all-zero sequences, so the reserved index 0 occurs as well.)

The embedding layer therefore needs an amount of memory that grows exponentially with n: for example, 132,651 potential trigrams times 48 embedding dimensions make about 6.4 million float32 parameters (roughly 25 MB), while 345,025,251 potential 5-grams times 80 dimensions would make about 27.6 billion parameters, well over 100 GB. This is why we stick to bigrams and trigrams. By the way, the numbers above were extracted as follows:

import pandas as pd

def all_ngrams(n):
    # collect the set of distinct n-grams occurring in any row of X
    length = MAX_LEN - n + 1
    ngrams = lambda x: set(zip(*[x[i:length + i]
                                 for i in range(0, n)]))
    return set().union(*[ngrams(x) for x in X])

ns = range(1, 6)
alphabet_size = len(ALPHABET)
counts = {'#(occurring n-grams)': [len(all_ngrams(n)) for n in ns],
          '#(potential n-grams)': [pow(alphabet_size, n) for n in ns]}
pd.DataFrame(counts, index=pd.Index(ns, name='n')).transpose()
