The Dialogue Dodecathlon
https://parl.ai/projects/dodecadialogue/TL;DR
Собрали вместе 12 диалоговых датасетов, обучили на них transformer-based seq2seq модель в multitasking режиме и получили SOTA на всех 12 задачах.
Суть подхода
Две идеи для обучения генеративной диалоговой модели, работающей в open-domain сеттинге:
1. Для предобучения лучше использовать диалоговые данные (Reddit), а не произвольные текстовые (например, WebText, на котором учили GPT2).
2. Лучше учить модели в multi-tasking режиме:
- во-первых, удобно иметь одну универсальную модель, а не 10-20 специализированных;
- во-вторых, в теории обучение на одних задачах может помочь в достижении хороших результатах на других; поэтому среди рассмотренных в статье датасетов есть не только текстовые, но и QA-датасеты по картинкам.
Датасеты, рассмотренные в статье:
- ConvAI - кондишен на факты о персоне
- DailyDialog - обсуждение разных повседневных тем
- Wiz. of Wikipedia - кондишен на факты из википедии
- Empathetic Dialog - обсуждение жизненные ситуаций в дружелюбной (терапевтической) манере
- Cornell Movie - субтитры
- LIGHT - roll-play в выдуманных ситуациях
- ELI5 - вопросы и ответы в длинной форме
- Ubuntu - чат поддержки
- Twitter - twitter
-
pushshift.io Reddit - 2.2 миллиарда предложений с реддита на разные темы
- Image Chat - обсуждение персон на картинках
- IGC - вопросы и ответы по картинкам на разные темы
Результаты
В качестве бейзлайна взяли предобученную GPT2-модель в реализации hugging face'a (которую в статье почему-то называют BERT'ом).
В качестве конкурента использовали transformer-based seq2seq модель из своего ParlAI, в которую в частности добавили возможность кондишениться на фичи, извлеченные из картинок.
Вывод 1.
Лучшая стратегия претрейна для диалоговых моделей - обучаться на огромном датасетете
pushshifit.io Reddit (2.2 миллиарда предложений). Претрейн на твиттере и использование весов GPT2 существенно проигрывает по perplexity.
Для справки - свой seq2seq на Reddit'e они они учили две недели на 64-ех Nvidia V100.
Вывод 2.
Если после предобучения на Reddit'e доучивать модель на всех 12-задачах в multi-tasking режиме, уже получается универсальная модель, которая бьет почти все предыдущие task-specific модели по perplexity и специфичным метрикам типа BLEU / ROUGE / F1.
Вывод 3.
Наиболее результативным остается подход с finetune'ом модели на конкретную задачу: сначала идет предобучение на Reddit'a, потом обучение на всех задачах в multitasking режими, а потом finetune на конкретную задачу. При таком подходе получаются новые SOTA-модели для всех 12 задач.
Вывод 4.
Пожалуй, самый интересный результат статьи связан с так называемым Leave-One-Out Zero-Shot Performance: дообучаемся в multitasking режиме на всех датасетах, кроме одного, а тестируемся на оставшемся.
Авторы статьи показали, что и в этом случае метрики на новом датасете также очень приличные (если только не выкидывать Reddit из дообучения), что говорит о том, что multitasking-обучение способствует лучшему обобщению модели и "переносу знаний" на новые домены.