REALM: Retrieval-Augmented Language Model Pre-Training
Guu et al. [Google Research]
arxiv.org/abs/2002.08909TL;DR:
добавили retreival-шаг перед MLM в предобучение, хорошо улучшили open domain question answering SOTA
Наконец-то почитал про REALM и оказалось что это ещё более интересная штука, чем я думал изначально. Основная идея простая: давайте помогать MLM дополнительным контекстом. Для этого будем использовать эмбеддинги BERT для поиска релевантных документов по нашему корпусу и использовать их в качестве контекста (см. картинку). При этом в качестве лосса берём взвешенное на предсказанные вероятности релевантности документа среднее по top8 документов. Трансформер делающий MLM и трансформер отвечающий за retreival - это две разные модели, но учатся вместе end-to-end. Зачем это нужно? Потому что это очень близко к ODQA-сетапу, когда мы просто хотим задать вопрос и получить ответ. И это действительно улучшает результаты T5 (которая сильно больше, чем REALM) с 35 до 40 пунктов.
Интересные детали.
1. Иногда дополнительного контекста для MLM не нужно (например, грамматическая конструкция) и для этого среди topk документов всегда есть
нулевой документ и его эмбеддинг тоже обучается.
1. Иногда нужные токены содержатся прямо в найденном документе, но мы бы хотели, чтобы моделька сама их предсказала, а не просто скопировала бы свой вход. Для этого все именованные сущности во вспомогательном документе (но не во вохде в MLM) типа дат и имён меняют на generic токены типа имя, дата. Их находят перед началом обучения с помощью готовой модели NER.
1. Иногда самый ближайший к MLM-инпуту документ содержит MLM-инпут, но без маски. Эти документы отсеиваются при ретриве.
1. В самом начале обучения BERT плохо ретривит (там после BERT есть доп линейный слой) и поэтому REALM может начать игнорировать вспомогательные документы. Для этого ретривер немножпо предобучается на задаче, где нужно найти документ по отрывку из этого документа.
Как ускоряли KNN:
KNN на нескольких миллионах документов будет работать очень долго. А нам нужно только top8. Для этого использовали построение индекса и приблизительный поиск.
Но тк моделька ретрива тоже обновляется с каждым батчом, индекс по-идее становится нерелевантным и его нужно обновлять. Вместо того, чтобы обновлять его каждый батч,
индекс обновляли каждые 500 итераций да ещё и асинхронно. Тк идея в том, что топ очень сильно не поменяется за 500 батчей.
По железу:
С одной стороны авторы говорят, что вся моделька поместится в
12Гб оперативки. Но с другой стороны они использовали
64 TPU для обучения и 16TPU для построения индекса. Вроде бы сопостовимо с BERT по объёмом, ничего особенно безумного, но кажется про времена, когда 8GPU было большой мощью уже в далёком прошлом.