Telegram Group & Telegram Channel
Cut Your Losses in Large-Vocabulary Language Models
Статья: https://arxiv.org/abs/2411.09009
Рецензии: https://openreview.net/forum?id=E4Fk3YuG56
Код: https://github.com/apple/ml-cross-entropy

Статья про оптимизацию памяти при подсчёте функции потерь и её ближайших градиентов при обучении языковых моделей. Основной механизм — модифицированная реализация перекрёстной энтропии, Cut Cross-Entropy (CCE). Авторы берут ровно ту же оптимизацию, которая используется в Flash Attention (поблочное вычисление в кэше GPU), но применяют её к последнему слою и последнему софтмаксу.

Последний шаг при предсказании следующего токена — линейный слой и софтмакс. На каждом шаге генерации у нас есть вектор E с последнего слоя трансформера, мы умножаем его на матрицу C, получаем логиты в ℝ^|V|, для каждого логита считаем экспоненту и делим на сумму всех логитов из всего словаря. Так для каждого токена получаем вероятность, число в отрезке [0, 1]. Функция потерь при обучении — логарифм вероятности правильного токена (с минусом). Нас интересует только правильный токен, и только его логит нам нужен в числителе софтмакса. Логарифм в лоссе гасит экспоненту в числителе. Вычисление раскладывается на две части: вычисление логита правильного токена и вычисление слагаемого нормализации по E и всем столбцам C (логарифм суммы экспонент).

При обучении мы можем считать всё параллельно для всех токенов, поэтому там уже не вектор E, а матрица E.

Для вычисления логитов правильных токенов авторы выгружают блоки релевантных столбцов C и блоки E в кэш, считают там скалярное произведение, и выгружают назад в основную память только финальный результат. Вычисление логарифма суммы экспонент гораздо хитрее, как и вычисление его градиентов, но концепция та же.

Кроме собственно оптимизаций с кэшом, используется тот факт, что большинство значений на выходе софтмакса "плохие", то есть очень близкие к нулю. Из-за ограниченной точности чисел с плавающей точкой, "плохие" значения ни на что не влияют при использовании в слагаемом нормализации. И для них авторы предлагают просто не считать градиенты. Вторая оптимизация такого рода — сортировка словаря по средним логитам, чтобы токены с "плохими" логитами попадали в один блок, и можно было такие блоки полностью пропускать.

По классификации в прошлом посте — это AG метод, полезен только при обучении. Есть и древние альтернативы, да хотя бы иерархический софтмакс или адаптивный софтмакс.

Экспериментально для Мистраля Немо удалось уменьшить память на лосс+градиенты с 8 Гб до 1.3 Гб, что лучше, чем в Liger Kernel. Аналогичная (и иногда даже более существенная) экономия памяти есть и для других моделей.

Потрогать можно через их библиотеку и патчинг модели. То есть вы делаете вот такое:

from cut_cross_entropy.transformers import cce_patch

model = ...
model = cce_patch(model)


После этого лосс и градиенты будут считаться как в статье. Но логиты не будут возвращаться, потому что они не материализуются в принципе.



group-telegram.com/senior_augur/349
Create:
Last Update:

Cut Your Losses in Large-Vocabulary Language Models
Статья: https://arxiv.org/abs/2411.09009
Рецензии: https://openreview.net/forum?id=E4Fk3YuG56
Код: https://github.com/apple/ml-cross-entropy

Статья про оптимизацию памяти при подсчёте функции потерь и её ближайших градиентов при обучении языковых моделей. Основной механизм — модифицированная реализация перекрёстной энтропии, Cut Cross-Entropy (CCE). Авторы берут ровно ту же оптимизацию, которая используется в Flash Attention (поблочное вычисление в кэше GPU), но применяют её к последнему слою и последнему софтмаксу.

Последний шаг при предсказании следующего токена — линейный слой и софтмакс. На каждом шаге генерации у нас есть вектор E с последнего слоя трансформера, мы умножаем его на матрицу C, получаем логиты в ℝ^|V|, для каждого логита считаем экспоненту и делим на сумму всех логитов из всего словаря. Так для каждого токена получаем вероятность, число в отрезке [0, 1]. Функция потерь при обучении — логарифм вероятности правильного токена (с минусом). Нас интересует только правильный токен, и только его логит нам нужен в числителе софтмакса. Логарифм в лоссе гасит экспоненту в числителе. Вычисление раскладывается на две части: вычисление логита правильного токена и вычисление слагаемого нормализации по E и всем столбцам C (логарифм суммы экспонент).

При обучении мы можем считать всё параллельно для всех токенов, поэтому там уже не вектор E, а матрица E.

Для вычисления логитов правильных токенов авторы выгружают блоки релевантных столбцов C и блоки E в кэш, считают там скалярное произведение, и выгружают назад в основную память только финальный результат. Вычисление логарифма суммы экспонент гораздо хитрее, как и вычисление его градиентов, но концепция та же.

Кроме собственно оптимизаций с кэшом, используется тот факт, что большинство значений на выходе софтмакса "плохие", то есть очень близкие к нулю. Из-за ограниченной точности чисел с плавающей точкой, "плохие" значения ни на что не влияют при использовании в слагаемом нормализации. И для них авторы предлагают просто не считать градиенты. Вторая оптимизация такого рода — сортировка словаря по средним логитам, чтобы токены с "плохими" логитами попадали в один блок, и можно было такие блоки полностью пропускать.

По классификации в прошлом посте — это AG метод, полезен только при обучении. Есть и древние альтернативы, да хотя бы иерархический софтмакс или адаптивный софтмакс.

Экспериментально для Мистраля Немо удалось уменьшить память на лосс+градиенты с 8 Гб до 1.3 Гб, что лучше, чем в Liger Kernel. Аналогичная (и иногда даже более существенная) экономия памяти есть и для других моделей.

Потрогать можно через их библиотеку и патчинг модели. То есть вы делаете вот такое:


from cut_cross_entropy.transformers import cce_patch

model = ...
model = cce_patch(model)


После этого лосс и градиенты будут считаться как в статье. Но логиты не будут возвращаться, потому что они не материализуются в принципе.

BY Старший Авгур


Warning: Undefined variable $i in /var/www/group-telegram/post.php on line 260

Share with your friend now:
group-telegram.com/senior_augur/349

View MORE
Open in Telegram


Telegram | DID YOU KNOW?

Date: |

"Like the bombing of the maternity ward in Mariupol," he said, "Even before it hits the news, you see the videos on the Telegram channels." I want a secure messaging app, should I use Telegram? "This time we received the coordinates of enemy vehicles marked 'V' in Kyiv region," it added. The Russian invasion of Ukraine has been a driving force in markets for the past few weeks. False news often spreads via public groups, or chats, with potentially fatal effects.
from ar


Telegram Старший Авгур
FROM American