BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding
Abstract:
Language Representation Model:
- BERT: Bidirectional Encoder Representations from Transformers.
- Self-supervised.
- Unlike recent language representation models such as ELMo (which shallowly concatenates separately trained left-to-right and right-to-left LMs) and OpenAI GPT (which is strictly left-to-right):
- BERT is designed to pre-train deep bidirectional representations by jointly conditioning on both left and right context in all layers.
- BERT can be fine-tuned with just one additional output layer to create state-of-the-art models for a wide range of tasks without substantial task-specific architecture modifications.
- BERT is conceptually simple and empirically powerful.
- At the time of publication, BERT obtained new state-of-the-art results on 11 natural language processing tasks.
Related Work:
Feature-Based / Fine-Tuning:
- Two existing strategies for applying pre-trained language representations to downstream tasks: feature-based (e.g., ELMo) and fine-tuning (e.g., OpenAI GPT).
- Learning without Forgetting: https://arxiv.org/pdf/1606.09282.pdf
BERT:
BERT Model Architecture:
- E_i = input embedding for token i.
- Trm = Transformer encoder block.
- T_i = final hidden (contextual) representation for token i.
BERT Input Representations:
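- Each input token's representation is the element-wise sum of its WordPiece token embedding, a segment (sentence A/B) embedding, and a learned position embedding.
- A minimal sketch of this sum, assuming TensorFlow/Keras layers and hypothetical sizes (the reference implementation also applies LayerNorm and dropout to the result):
    import tensorflow as tf

    # Hypothetical sizes for illustration (BERT-Base uses a ~30k WordPiece vocab and hidden size 768).
    vocab_size, max_len, hidden = 30000, 512, 768

    token_emb = tf.keras.layers.Embedding(vocab_size, hidden)    # WordPiece token embeddings
    segment_emb = tf.keras.layers.Embedding(2, hidden)           # sentence A vs. sentence B
    position_emb = tf.keras.layers.Embedding(max_len, hidden)    # learned positions (not sinusoidal)
    layer_norm = tf.keras.layers.LayerNormalization()

    def bert_input_representation(token_ids, segment_ids):
        # token_ids, segment_ids: [batch, seq_len] integer tensors
        positions = tf.range(tf.shape(token_ids)[1])              # [seq_len]
        summed = token_emb(token_ids) + segment_emb(segment_ids) + position_emb(positions)
        return layer_norm(summed)                                 # dropout omitted in this sketch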
Transformers: Attention Is All You Need:
- Attention Is All You Need: https://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf
- The Annotated Transformer: http://nlp.seas.harvard.edu/2018/04/03/attention.html
Transformers: Details:
- Use Encoder only.
- Use WordPiece embeddings (a ~30,000-token subword vocabulary) rather than word-level embeddings such as word2vec; rare words are split into subword units, e.g., "flightless" → "flight ##less".
- Learn position embeddings rather than using the fixed sinusoidal positional encodings of the original Transformer.
- Also use scaled dot-product attention with Q, K, V = (queries, keys, values); see the sketch after this list.
- Self-attention in the encoder: Q, K, and V are all derived from the same input sequence.
- Reference for the key/value terminology: Key-Value Memory Networks for Directly Reading Documents.
- Also use LayerNorm.
- Also use residual connection.
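- A hedged sketch of scaled dot-product attention as used inside each encoder layer (function and argument names are illustrative, not the reference implementation):
    import tensorflow as tf

    def scaled_dot_product_attention(q, k, v, mask=None):
        # Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V, from "Attention Is All You Need".
        d_k = tf.cast(tf.shape(k)[-1], tf.float32)
        scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(d_k)   # [batch, seq_q, seq_k]
        if mask is not None:
            scores += (1.0 - mask) * -1e9      # push padded positions toward zero attention weight
        weights = tf.nn.softmax(scores, axis=-1)
        return tf.matmul(weights, v)           # [batch, seq_q, d_v]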
Pre-training Task #1: Masked LM
- Standard conditional language models can only be trained left-to-right or right-to-left, since bidirectional conditioning would allow each word to indirectly “see itself” in a multi-layered context.
- Solution: predict tokens from their surrounding context instead, in the spirit of the Cloze task, word2vec-style context prediction, and denoising autoencoders.
- Masked Language Model: mask some percentage of the input tokens at random, then predict only those masked tokens.
- Mask 15% of all WordPiece tokens in each sequence at random.
- Only predict the masked words rather than reconstructing the entire input.
- Example:
- Input: the man went to the [MASK1] . he bought a [MASK2] of milk.
- Labels: [MASK1] = store; [MASK2] = gallon
Pre-training Task #1: Downsides
- The first is that we are creating a mismatch between pre-training and fine-tuning, since the [MASK] token is never seen during fine-tuning.
- To mitigate this, we do not always replace “masked” words with the actual [MASK] token.
- Rather than always replacing the chosen words with [MASK], the data generator does the following (a code sketch follows this list):
- 80% of the time: Replace the word with the [MASK] token, e.g., my dog is hairy → my dog is [MASK]
- 10% of the time: Replace the word with a random word, e.g., my dog is hairy → my dog is apple
- 10% of the time: Keep the word unchanged, e.g., my dog is hairy → my dog is hairy. The purpose of this is to bias the representation towards the actual observed word.
- The second downside of using an MLM is that only 15% of tokens are predicted in each batch, which suggests that more pre-training steps may be required for the model to converge.
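- A hedged sketch of the masking data generator described above (plain Python, toy vocabulary; the real generator works on WordPiece tokens and never masks [CLS] or [SEP]):
    import random

    MASK_TOKEN = "[MASK]"
    RANDOM_VOCAB = ["store", "gallon", "apple", "penguin", "milk"]   # toy stand-in for the WordPiece vocab

    def mask_tokens(tokens, mask_prob=0.15):
        """Select ~15% of tokens at random and apply the 80/10/10 replacement rule."""
        output, targets = list(tokens), {}
        positions = list(range(len(tokens)))
        random.shuffle(positions)
        num_to_mask = max(1, int(round(len(tokens) * mask_prob)))
        for i in positions[:num_to_mask]:
            targets[i] = tokens[i]             # the original token is the prediction target
            r = random.random()
            if r < 0.8:
                output[i] = MASK_TOKEN                     # 80%: replace with [MASK]
            elif r < 0.9:
                output[i] = random.choice(RANDOM_VOCAB)    # 10%: replace with a random token
            # else: 10%: keep the token unchanged
        return output, targets

    # Example: mask_tokens("my dog is hairy".split()) might return
    # (["my", "dog", "is", "[MASK]"], {3: "hairy"}).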
Pre-training Task #2: Next Sentence Prediction
- Many important downstream tasks such as Question Answering (QA) and Natural Language Inference (NLI) are based on understanding the relationship between two text sentences, which is not directly captured by language modeling.
- In order to train a model that understands sentence relationships, we pre-train a binarized next sentence prediction task that can be trivially generated from any monolingual corpus.
- Example #1:
- [CLS] the man went to [MASK] store [SEP] he bought a gallon [MASK] milk [SEP]
- Label: IsNext
- Example #2: (negative sampling)
- [CLS] the man [MASK] to the store [SEP] penguin [MASK] are flight ##less birds [SEP]
- Label: NotNext
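- A minimal sketch of how IsNext/NotNext pairs could be generated (assumes each document is a list of tokenized sentences with at least two sentences; the real generator also packs pairs up to the maximum sequence length):
    import random

    def make_nsp_example(documents):
        """50% of the time B is the actual next sentence; 50% of the time B is sampled elsewhere."""
        doc = random.choice(documents)
        i = random.randrange(len(doc) - 1)
        sent_a = doc[i]
        if random.random() < 0.5:
            sent_b, label = doc[i + 1], "IsNext"
        else:
            other = random.choice(documents)       # negative sampling (the real code ensures other != doc)
            sent_b, label = random.choice(other), "NotNext"
        tokens = ["[CLS]"] + sent_a + ["[SEP]"] + sent_b + ["[SEP]"]
        segment_ids = [0] * (len(sent_a) + 2) + [1] * (len(sent_b) + 1)
        return tokens, segment_ids, label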
Pre-training Tasks:
- We do not use traditional left-to-right or right-to-left language models to pre-train BERT.
- Instead, we pre-train BERT using two novel unsupervised prediction tasks.
- total_loss = masked_lm_loss + next_sentence_loss
- Masked-LM loss (cross-entropy over the masked positions only, as in the open-source implementation):
- log_probs = tf.nn.log_softmax(logits, axis=-1)                            # log-probabilities over the vocabulary for each masked position
- per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])  # cross-entropy per masked position
- numerator = tf.reduce_sum(label_weights * per_example_loss)               # label_weights zeroes out padded prediction slots
- denominator = tf.reduce_sum(label_weights) + 1e-5                         # small epsilon avoids division by zero
- loss = numerator / denominator                                            # mean loss over the real masked positions
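- For completeness, a hedged sketch of the next-sentence loss that feeds into total_loss above (assumes next_sentence_logits of shape [batch, 2] computed from the [CLS] representation, with the label convention 0 = IsNext, 1 = NotNext):
    import tensorflow as tf

    def next_sentence_loss(next_sentence_logits, next_sentence_labels):
        # next_sentence_logits: [batch, 2]; next_sentence_labels: [batch] integer labels
        log_probs = tf.nn.log_softmax(next_sentence_logits, axis=-1)
        one_hot_labels = tf.one_hot(next_sentence_labels, depth=2, dtype=tf.float32)
        per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
        return tf.reduce_mean(per_example_loss)    # averaged over the batch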
Pre-training Procedure:
- For the pre-training corpus we use the concatenation of BooksCorpus (800M words) and English Wikipedia (2,500M words).
- Training of BERTBASE was performed on 4 Cloud TPUs in Pod configuration (16 TPU chips total).
- Training of BERTLARGE was performed on 16 Cloud TPUs (64 TPU chips total).
- Each pre-training took 4 days to complete.
- Pricing? https://cloud.google.com/products/calculator/?hl=zh-TW
Fine-tuning Procedure:
- For sequence-level classification tasks, BERT fine-tuning is straightforward.
- In order to obtain a fixed-dimensional pooled representation of the input sequence, we take the final hidden state (i.e., the output of the Transformer) for the first token in the input, which by construction corresponds to the special [CLS] word embedding.
- We denote this vector as C ∈ R^H.
- The only new parameters added during fine-tuning are for a classification layer W ∈ R^{K×H}, where K is the number of classifier labels.
- The label probabilities P ∈ R^K are computed with a standard softmax, P = softmax(CW^T).
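- A minimal sketch of this classification head (also the head used for MNLI below); the variable names and the bias term are illustrative, and C is the final hidden state of [CLS]:
    import tensorflow as tf

    hidden_size, num_labels = 768, 3      # e.g. K = 3 for MNLI (entailment / contradiction / neutral)

    W = tf.Variable(tf.random.truncated_normal([num_labels, hidden_size], stddev=0.02))
    b = tf.Variable(tf.zeros([num_labels]))

    def classify(pooled_cls_output, labels):
        # pooled_cls_output: [batch, hidden_size] -- the vector C for each example
        logits = tf.matmul(pooled_cls_output, W, transpose_b=True) + b    # C W^T (+ bias)
        log_probs = tf.nn.log_softmax(logits, axis=-1)                    # P = softmax(C W^T)
        one_hot = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)
        loss = -tf.reduce_mean(tf.reduce_sum(one_hot * log_probs, axis=-1))
        return logits, loss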
Fine-tuning:
- For fine-tuning, most model hyperparameters are the same as in pre-training, with the exception of the batch size, learning rate, and number of training epochs.
- The dropout probability was always kept at 0.1.
- The optimal hyperparameter values are task-specific, but the paper recommends searching over batch size {16, 32}, Adam learning rate {5e-5, 3e-5, 2e-5}, and 3-4 training epochs.
Experiments:
Sentence Pair Classification: MNLI:
- MNLI: Given a pair of sentences, the goal is to predict whether the second sentence is an entailment, contradiction, or neutral with respect to the first one.
- The only new parameters introduced during fine-tuning are a classification layer W ∈ R^{K×H}, where K is the number of labels.
- We compute a standard classification loss with C and W, i.e., log(softmax(CW^T)).
- GLUE leaderboard (at the time of these notes):
- #3: BERT
- #2: MT-DNN (Microsoft)
- #1: Human
(Single Sentence Classification Tasks: similar to Sentence Pair Classification Tasks)
Question Answering: SQuAD v1.1:
- SQuAD v1.1: Given a question and a paragraph from Wikipedia containing the answer, the task is to predict the answer text span in the paragraph.
- The only new parameters learned during fine-tuning are a start vector S ∈ R^H and an end vector E ∈ R^H.
- The probability of word i being the start of the answer span is computed as a dot product between T_i and S followed by a softmax over all of the words in the paragraph: P_i = e^{S·T_i} / Σ_j e^{S·T_j}. The end of the span is handled analogously with E.
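- A hedged sketch of the span-prediction head (assumes sequence_output of shape [batch, seq_len, hidden] holding the final hidden states T_i; S and E are the only new parameters):
    import tensorflow as tf

    hidden_size = 768
    S = tf.Variable(tf.random.truncated_normal([hidden_size], stddev=0.02))   # start vector
    E = tf.Variable(tf.random.truncated_normal([hidden_size], stddev=0.02))   # end vector

    def span_probabilities(sequence_output):
        # Dot product of every T_i with S (or E), then a softmax over the paragraph tokens.
        start_logits = tf.einsum("bih,h->bi", sequence_output, S)   # T_i . S
        end_logits = tf.einsum("bih,h->bi", sequence_output, E)     # T_i . E
        return tf.nn.softmax(start_logits, axis=-1), tf.nn.softmax(end_logits, axis=-1)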
Named Entity Recognition: CoNLL-2003 NER:
- CoNLL-2003 NER: Locate and classify named entity mentions in unstructured text into predefined categories.
- For fine-tuning, we feed the final hidden representation T_i ∈ R^H for each token i into a classification layer over the NER label set.
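- A minimal sketch of the per-token classification head for NER (hypothetical label-set size; each T_i gets its own softmax over the NER labels):
    import tensorflow as tf

    hidden_size, num_ner_labels = 768, 9      # e.g. BIO-style tags for CoNLL-2003

    W_ner = tf.Variable(tf.random.truncated_normal([num_ner_labels, hidden_size], stddev=0.02))

    def ner_probabilities(sequence_output):
        # sequence_output: [batch, seq_len, hidden] -- final hidden representation T_i of each token
        logits = tf.einsum("bih,kh->bik", sequence_output, W_ner)    # per-token scores over NER labels
        return tf.nn.softmax(logits, axis=-1)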
Ablation Studies:
Effect of Pre-training Tasks:
Effect of Model Size:
Effect of Number of Training Steps:
Feature-based Approach with BERT:
Conclusion:
- Recent empirical improvements due to transfer learning with language models have demonstrated that rich, unsupervised pre-training is an integral part of many language understanding systems.
- In particular, these results enable even low-resource tasks to benefit from very deep unidirectional architectures.
- BERT's major contribution is generalizing these findings to deep bidirectional architectures, allowing the same pre-trained model to successfully tackle a broad set of NLP tasks.
- While the empirical results are strong, in some cases surpassing human performance, important future work is to investigate the linguistic phenomena that may or may not be captured by BERT.
References:
- Google AI Blog: https://ai.googleblog.com/2018/11/open-sourcing-bert-state-of-art-pre.html
- Source code: https://github.com/google-research/bert
- The Illustrated BERT, ELMo, and co. (How NLP Cracked Transfer Learning): https://jalammar.github.io/illustrated-bert/