group-telegram.com/boris_again/2992
Last Update:
The Pitfalls of Next-Token Prediction
Статья: https://arxiv.org/abs/2403.06963
Видео: https://www.youtube.com/watch?v=9V0bfZqT1Yo
Олды несомненно помнят, что в ранних seq2seq моделях, основанных на рекуррентных нейронных сетях, существовало два режима обучения: teacher-forcing, где на каждом шаге генерации в качестве входов использовались реальные токены, и другой режим с использованием токенов, предсказанных текущей версией модели. С появлением трансформеров и их параллельного обучения все стали использовать teacher-forcing. Авторы статьи возвращаются к этому вопросу.
🔹Задача
Авторы придумали простую синтетическую задачу: поиск пути между двумя вершинами в деревьях очень специфичной структуры, а именно в таких, где есть одна центральная вершина и несколько цепочек, исходящих из этой центральной вершины. Пример такого дерева (степень центральной вершины = 2, длина цепочек = 5):
8 ← 1 ← 5 ← 4 ← 3 → 0 → 2 → 6 → 7
Условия задачи:
— Степень центральной вершины и длина цепочек фиксированы для всех деревьев в обучающей и тестовой выборке.
— Путь всегда начинается в центральной вершине.
— Путь всегда заканчивается в одном из листьев.
Вход для задачи выглядит как случайно перемешанный набор рёбер дерева, плюс начало и конец пути (после "/"):
3 → 4 | 5 → 1 | 4 → 5 | 0 → 2 | 3 → 0 | 1 → 8 | 6 → 7 | 2 → 6 / 3 7
Выход выглядит как сам путь:
3 → 0 → 2 → 6 → 7
Эту задачу мы решаем какой-нибудь моделью, которая умеет работать с последовательностями, например трансформером или рекуррентной сетью в авторегрессионном режиме (генерация токенов слева направо, как в языковых моделях).
🔹Эмпирическая часть
— Авторегрессионные модели не справляются с решением этой задачи даже для деревьев с фиксированной структурой. Потому что сложно понять в какую сторону идти от центральной вершины.
— При развороте пути задача успешно решается авторегрессионными моделями. Это логично, потому что это гораздо проще: вы просто поднимаетесь по родителям, пока не найдёте центральную вершину.
— Если во время обучения маскировать уже сгенерированную часть пути, модели также успешно решают задачу. Это странно, потому что мы делаем задачу сложнее для модели, заставляя её генерировать весь путь сразу. Но каким-то образом на такой версии задачи модель учится, а на оригинальной — нет.
Я потратил пару вечеров и воспроизвёл это в Колабе: ссылка. Воспроизводил для 2-5 деревьев, то есть ровно таких, как в примере выше. Код писал с нуля, но опираясь на их Гитхаб. Всё получилось, как написано в статье: усложнение задачи приводит к возможности её выучивания. Технически это выглядит просто как маскирование части input_ids.
🔹Про предсказание следующего токена
Щепотка "соломенного чучела": распространенная критика языковых моделей состоит в том, что они являются лишь "стохастическими попугаями", способными только предсказывать следующий токен. Считается, что из-за этого они не могут эффективно планировать или исправлять ошибки.
Однако авторы статьи предполагают, что основная проблема не в механизме предсказания следующего токена как таковом. Проблема — в teacher forcing'е, то есть в том, что во время обучения у модели нет необходимости планировать и пытаться сформулировать решение в активациях. И ведь большинство современных моделей обучалось именно с использованием этого метода.
🔹Ограничения
— Эмпирическая часть работает при фиксированном наборе гиперпараметров, и сломав их, можно сломать 2 и 3 наблюдение. Это прежде всего оптимизационная задача. Однако ни у меня, ни у авторов не получилось сделать модель, которая была бы контрпримером для первого наблюдения.
— У авторов нет никакого теоретического обоснования наблюдений. Как нет и алгоритма, по которому сеть считает путь. Мне кажется, что тут есть простор для творчества и механистической интерпретации.