paper summary: “LayoutLMv3: Pre-training for Document AI with Unified Text and Image Masking”
arxiv: https://arxiv.org/abs/2204.08387
key points
- use a linear embedding for image patches instead of a dedicated CNN-based network, making the model simpler
- use three pretraining tasks: MLM, MIM, WPA
I think this work builds heavily on its predecessor LayoutLMv2, so if you haven’t checked that out, I recommend reading my previous post on it first. If you have questions about architecture details, you may find hints in the LayoutLMv2 paper.
Model Architecture
The model is a stack of transformer encoder layers that takes text tokens and image tokens as input.
text embedding
use off-the-shelf OCR to get the text and its 2d positions.
word embedding + position embedding
word embedding: use the word embedding matrix from RoBERTa
position embedding: 1d position + 2d position embedding
instead of word-level layout positions, use segment-level layout positions: words in the same segment share the same 2d position.
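To make this concrete, here is a minimal sketch (my own, with hypothetical data structures, not the authors’ code) of what segment-level 2d positions look like: every word in an OCR segment reuses the bounding box of the whole segment instead of its own word box.

```python
# Sketch only: assumes OCR output is a list of segments, each with its words and a segment bbox.
def segment_level_boxes(segments):
    """Return parallel lists of words and their (shared) segment-level 2D positions."""
    words, boxes = [], []
    for seg in segments:
        for w in seg["words"]:
            words.append(w)
            boxes.append(seg["bbox"])   # every word reuses the segment's box
    return words, boxes

ocr_output = [
    {"words": ["Invoice", "Number:"], "bbox": (40, 20, 320, 50)},
    {"words": ["Total", "Due"],       "bbox": (40, 400, 180, 430)},
]
words, boxes = segment_level_boxes(ocr_output)
# "Invoice" and "Number:" now share the same 2D position (40, 20, 320, 50)
```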
image embedding
divide the image into patches, pass them through a linear projection, and feed them to the multimodal transformer. the processing details are as follows (a small sketch follows the list):
- resize the image to HxW, giving shape (H, W, C)
- split the image into PxP-sized patches
- each patch is flattened and linearly projected to D dimensions, giving shape (HW/P², D)
- add a learnable 1d position embedding for each patch. 2d position embeddings were tested but gave no improvement.
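A minimal PyTorch sketch of this linear patch embedding, using the paper’s defaults of a 224x224 input and P=16 (so 196 patches). The naming is mine, and the strided convolution is just an equivalent way of writing “flatten each patch, then apply a shared linear projection”.

```python
import torch
import torch.nn as nn

class LinearPatchEmbedding(nn.Module):
    """Sketch of the CNN-free image embedding: split the image into PxP patches,
    project each patch to D dims, and add a learnable 1D position embedding."""
    def __init__(self, img_size=224, patch_size=16, in_chans=3, dim=768):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2          # 14 * 14 = 196
        # A conv with kernel=stride=P is equivalent to flatten + linear projection per patch.
        self.proj = nn.Conv2d(in_chans, dim, kernel_size=patch_size, stride=patch_size)
        self.pos_emb = nn.Parameter(torch.zeros(1, self.num_patches, dim))

    def forward(self, images):                                     # (B, 3, 224, 224)
        x = self.proj(images)                                      # (B, D, 14, 14)
        x = x.flatten(2).transpose(1, 2)                           # (B, 196, D)
        return x + self.pos_emb                                    # add 1D position embedding

patches = LinearPatchEmbedding()(torch.randn(2, 3, 224, 224))
print(patches.shape)                                               # torch.Size([2, 196, 768])
```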
insert semantic relative 1d positions and spatial relative 2d positions as bias terms in the self-attention, as done in LayoutLMv2.
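As a rough illustration of these relative-position biases, here is a simplified sketch (not the authors’ implementation: real versions typically bucket relative distances, e.g. T5-style, whereas here they are simply clipped).

```python
import torch
import torch.nn as nn

class RelativeBiasAttentionScores(nn.Module):
    """Simplified sketch: add learnable 1D (token index) and 2D (layout x/y) relative
    position biases to the attention logits, in the spirit of LayoutLMv2."""
    def __init__(self, num_heads=12, max_rel=32):
        super().__init__()
        self.max_rel = max_rel
        self.bias_1d = nn.Embedding(2 * max_rel + 1, num_heads)
        self.bias_x = nn.Embedding(2 * max_rel + 1, num_heads)
        self.bias_y = nn.Embedding(2 * max_rel + 1, num_heads)

    def _rel(self, pos):                       # pos: (L,) long -> clipped relative distances (L, L)
        rel = pos[None, :] - pos[:, None]
        return rel.clamp(-self.max_rel, self.max_rel) + self.max_rel

    def forward(self, q, k, pos_1d, pos_x, pos_y):
        # q, k: (B, heads, L, d_head); pos_*: (L,) integer positions
        scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5      # (B, heads, L, L)
        for table, pos in ((self.bias_1d, pos_1d), (self.bias_x, pos_x), (self.bias_y, pos_y)):
            scores = scores + table(self._rel(pos)).permute(2, 0, 1)  # (heads, L, L) broadcast
        return scores
```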
Pretraining objectives
Masked Language Modeling (MLM)
mask 30% of the text tokens with span masking, with span lengths drawn from a Poisson distribution.
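A minimal sketch of span masking with Poisson-distributed span lengths. The 30% masking ratio is from the paper; λ = 3 is my assumed default (the value commonly used for BART-style span masking).

```python
import numpy as np

def span_mask(seq_len, mask_ratio=0.3, poisson_lam=3.0, seed=0):
    """Sketch of span masking: repeatedly draw a span length from a Poisson
    distribution and mask a random contiguous span until ~30% of tokens are masked."""
    rng = np.random.default_rng(seed)
    mask = np.zeros(seq_len, dtype=bool)
    while mask.sum() < mask_ratio * seq_len:
        length = max(1, rng.poisson(poisson_lam))
        start = rng.integers(0, max(1, seq_len - length))
        mask[start:start + length] = True
    return mask

print(span_mask(512).mean())   # roughly 0.3
```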
Masked Image Modeling (MIM)
borrowed from BEiT
randomly mask 40% of the image tokens with a blockwise masking strategy
the model is trained to reconstruct the masked image tokens.
labels for the image tokens come from an image tokenizer.
MIM facilitates learning high-level layout structures (a small sketch of the blockwise masking follows).
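A simplified sketch of BEiT-style blockwise masking on the 14x14 patch grid (my own simplification; the actual BEiT algorithm constrains block sizes and aspect ratios more carefully). For each masked patch, the MIM head then predicts the discrete code assigned by the image tokenizer, i.e. a cross-entropy classification over the visual vocabulary.

```python
import numpy as np

def blockwise_mask(grid=14, mask_ratio=0.4, min_block=2, seed=0):
    """Simplified blockwise masking: mask random rectangular blocks of patches
    until ~40% of the grid is covered, then flatten to match the patch sequence."""
    rng = np.random.default_rng(seed)
    mask = np.zeros((grid, grid), dtype=bool)
    while mask.mean() < mask_ratio:
        h = rng.integers(min_block, grid // 2)
        w = rng.integers(min_block, grid // 2)
        top = rng.integers(0, grid - h)
        left = rng.integers(0, grid - w)
        mask[top:top + h, left:left + w] = True
    return mask.reshape(-1)          # (196,) boolean mask over image patches

print(blockwise_mask().mean())       # roughly 0.4
```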
Word-Patch Alignment (WPA)
even with MLM and MIM, nothing explicitly teaches the model to match a text word to its corresponding area in the image. WPA is used to make the model learn this “alignment”.
predict whether the image patch corresponding to a text word is masked: a binary classification between “aligned” and “unaligned”.
assign the “aligned” label to text tokens that are unmasked (in MLM) and whose image patches are also unmasked.
assign the “unaligned” label to text tokens that are unmasked (in MLM) but whose image patches are masked.
text tokens masked by MLM are excluded from the WPA loss calculation. Remember, MLM, MIM, and WPA are applied simultaneously, so we need to be careful about such cases.
use a binary cross-entropy loss.
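Putting the label-assignment rules together, here is a small sketch of how the WPA loss could be computed (the names and the alignment-head output are hypothetical, not the authors’ code):

```python
import torch
import torch.nn.functional as F

def wpa_loss(logits, text_masked, patch_masked, word_to_patch):
    """Sketch of the WPA objective.
    logits:        (L,) alignment scores per text token (output of a hypothetical head)
    text_masked:   (L,) bool, True if the token was masked by MLM
    patch_masked:  (P,) bool, True if the image patch was masked by MIM
    word_to_patch: (L,) long, index of the image patch each word falls into
    """
    # "aligned" (1) if the word's patch is NOT masked, "unaligned" (0) if it is.
    labels = (~patch_masked[word_to_patch]).float()
    # Tokens masked by MLM are excluded from the WPA loss.
    keep = ~text_masked
    return F.binary_cross_entropy_with_logits(logits[keep], labels[keep])
```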
Model Variants and Details
two variants: base & large
base: 12-layer encoder, 12-head self-attention, hidden size=768, feed-forward intermediate size=3072
large: 24-layer encoder, 16-head self-attention, hidden size=1024, feed-forward intermediate size=4096
text input is tokenized with BPE, with a maximum sequence length of 512.
add [CLS] and [SEP] at the beginning and end of the text sequence.
bounding box coordinates for the special tokens are all zeros.
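A small sketch of assembling the text input under these rules (the `tokenize` callable and the special-token ids are placeholders; 0 and 2 are just RoBERTa-like illustrative defaults):

```python
def build_text_input(words, boxes, tokenize, cls_id=0, sep_id=2, max_len=512):
    """Sketch: prepend [CLS] and append [SEP], give special tokens an all-zero
    bounding box, and truncate to the maximum sequence length of 512."""
    input_ids, bboxes = [cls_id], [(0, 0, 0, 0)]          # special tokens get zero boxes
    for word, box in zip(words, boxes):
        for tok_id in tokenize(word):                     # BPE may split a word into several tokens
            input_ids.append(tok_id)
            bboxes.append(box)                            # subwords share the word/segment box
    input_ids, bboxes = input_ids[: max_len - 1], bboxes[: max_len - 1]
    input_ids.append(sep_id)
    bboxes.append((0, 0, 0, 0))
    return input_ids, bboxes
```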
image embedding parameters
- input size (CHW): 3x224x224
- P=16
- M=196 (where M=HW/P² in image slicing; 224/16=14 patches per side, and 14²=196)
use distributed training, mixed precision training, gradient accumulation
follow CogView's change to the attention-score computation (to alleviate precision issues under mixed-precision training)
Pretraining
pretrain on the large IIT-CDIP dataset, which has 11 million document images (42 million pages); only 11 million pages are used.
no image augmentation.
weights are initialized from pretrained RoBERTa.
the image tokenizer is initialized from the pretrained image tokenizer in DiT; vocab size is 8192.
other parameters are randomly initialized.
use the Adam optimizer with batch size 2048, run for 500k steps.
Experiment
test on four downstream tasks
form understanding
test on the FUNSD dataset, on the entity labeling task.
layoutlmv3 achieves SOTA by a significant margin
receipt understanding
CORD dataset
layoutlmv3 achieves SOTA
document image classification
RVL-CDIP dataset. extract text and layout information using Microsoft OCR.
layoutlmv3 achieves results better than or comparable to previous works, with a much smaller model size.
compared with layoutlmv2, which uses a dedicated network for image embedding, this shows that simple linear image embeddings are capable of handling image-centric tasks.
document visual question answering
use DocVQA dataset.
although DocVQA's original task is to take a document image and a question as input and output the answer, and DocVQA doesn't provide text boxes, the authors formalized it as an extractive QA problem where the model predicts the answer's start and end positions among the text tokens.
use the Microsoft Read API to extract text and bounding boxes, and use heuristics to find the start and end positions of the answer.
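The paper doesn't spell these heuristics out, but a toy version might look like this (purely illustrative): scan the OCR tokens for an exact match of the answer tokens and take its start/end indices.

```python
def find_answer_span(tokens, answer_tokens):
    """Sketch of a simple heuristic for building extractive-QA labels: return the
    start/end indices of the first case-insensitive match of the answer tokens,
    or None if the answer cannot be located in the OCR tokens."""
    tokens = [t.lower() for t in tokens]
    answer_tokens = [t.lower() for t in answer_tokens]
    n = len(answer_tokens)
    for start in range(len(tokens) - n + 1):
        if tokens[start:start + n] == answer_tokens:
            return start, start + n - 1
    return None

print(find_answer_span(["Total", "amount", "due", ":", "$42.00"], ["$42.00"]))  # (4, 4)
```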
layoutlmv3 shows better results than layoutlmv2.
document layout detection
while the previous 4 tasks were text-centric, this bonus experiment is image-centric: the model doesn't receive any text tokens, only image tokens.
Use the PubLayNet dataset, an object detection task on document images.
Use LayoutLMv3 as the feature backbone for a Cascade R-CNN detector.
results show that using a pretrained LayoutLMv3 as the feature backbone and finetuning it outperforms the others. this indicates that LayoutLMv3's multimodal pretraining can help even in settings where text tokens are not used.
I think this shows that the multimodally (text and image) pretrained LayoutLMv3 has, to some extent, learned to infer text information from the image embeddings, and that this helped in this experiment.
Ablation Study
- using only text tokens actually works quite well
- adding image tokens without any image-token-related pretraining task can actually be worse than using text tokens only.
- when using image tokens, adding an image-token-related pretraining task (MIM or WPA) helps training greatly, and that's when the model really shows its potential
- as expected, using all three pretraining tasks gives the best performance
comments & questions
- no mention of the definition of a “segment” used for sharing 2d position information among words.
- interesting use of an “image tokenizer” from another work, whose outputs serve as labels for image-token reconstruction.