Sharded training — подход, при котором параметры, градиенты и состояния оптимизатора делятся на части и распределяются между устройствами для экономии памяти.
Определение
Sharded training — это метод распределённого обучения, который уменьшает объём памяти на каждом GPU путём разбиения (шардирования) параметров модели, градиентов и состояний оптимизатора. В отличие от классического data parallelism, где каждая копия модели полностью дублируется на всех устройствах, sharded training хранит только части данных, что позволяет обучать значительно более крупные модели.
Подход используется в DeepSpeed, Fully Sharded Data Parallel (FSDP) и других системах, обеспечивая масштабируемость при ограниченной памяти.
Как работает
Основная идея sharded training — разбить большие тензоры, связанные с обучением, на части и распределить их между устройствами. Во время forward и backward pass недостающие части собираются по запросу, а по завершении вычислений отправляются обратно.
Типичные элементы механизма:
- Шардирование параметров — каждый GPU хранит только свой фрагмент весов.
- Шардирование градиентов — градиенты распределяются после вычислений, исключая их дублирование.
- Шардирование состояний оптимизатора — momentum, variance и другие состояния также делятся между устройствами.
- On-demand all-gather — GPU собирает нужные фрагменты параметров только перед вычислением.
- Reduce-scatter — после backward градиенты собираются и распределяются по устройствам.
Таким образом, каждое устройство хранит лишь небольшой процент полной модели, но участвует в совместных вычислениях. При этом коммуникационные операции становятся ключевым фактором производительности.
Где применяется
- Обучение моделей, не помещающихся в память одного GPU.
- Тренировка больших Transformer-моделей.
- DeepSpeed ZeRO-2 и ZeRO-3.
- FSDP в PyTorch.
- Тренинг архитектур MoE с большим количеством экспертов.
- Суперкомпьютерные распределённые пайплайны.
Практические примеры использования
В FSDP PyTorch шардированию подвергаются параметры каждого слоя. Перед выполнением слоя параметры собираются через all-gather, после чего выполняются вычисления, и параметры снова освобождаются. Такой подход позволяет тренировать модели с сотнями миллиардов параметров.
В DeepSpeed шардирование используется в ZeRO-2 и ZeRO-3: параметры, градиенты и состояния оптимизатора делятся на части. Это резко уменьшает memory footprint каждого GPU и даёт возможность обучать большие модели на ограниченном количестве устройств.
В больших MoE-системах sharded training применяется для распределения экспертов между GPU, снижая объём VRAM, необходимый для роутеров и FFN-экспертов.
В гибридных системах sharded training используется вместе с tensor parallelism и pipeline parallelism для максимальной масштабируемости и снижения пикового использования памяти.
Ключевые свойства sharded training
- Минимизация дубликатов — данные хранятся только один раз по кластеру.
- Гибкость — подходит как для слоёв, так и для матриц и состояний оптимизатора.
- Снижение VRAM-потребления — ключевое преимущество по сравнению с data parallelism.
- Интеграция с масштабированием — работает вместе с другими методами параллелизма.
Проблемы и ограничения
- Зависимость от коммуникаций — частые all-gather и reduce-scatter требуют NVLink или InfiniBand.
- Рост latency — из-за динамического сбора параметров.
- Сложность настройки — необходимо подбирать уровни шардирования и стратегию offloading.
- Сложное профилирование — трудно найти коммутационные узкие места.
- Чувствительность к размеру batch — маленькие batch увеличивают overhead.
Преимущества и ограничения
- Плюс: обучает модели, которые невозможно разместить в одном GPU.
- Плюс: снижает объём памяти на устройство.
- Плюс: уменьшает дублирование состояний оптимизатора.
- Плюс: хорошо работает совместно с другими видами параллелизма.
- Минус: требует дорогой сетевой инфраструктуры.
- Минус: увеличивает количество коммуникаций.
- Минус: усложняет пайплайн обучения.
- Минус: требует тщательного профилирования.
Связанные термины
- ZeRO optimization
- FSDP (Fully Sharded Data Parallel)
- Data parallelism
- Tensor parallelism
- Pipeline parallelism
- Distributed training
- All-reduce
- Offloading