FSDP: Experiences on Scaling Fully Sharded Data Parallel

一、论文概览

属性内容
标题PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel
arXiv2304.11277
机构Meta (PyTorch Team)
代码PyTorch Distributed (内置)

核心贡献

  1. PyTorch 原生 ZeRO-3:FSDP 将 ZeRO-3 思想集成到 PyTorch 生态
  2. 混合分片策略:支持 full shard / shard-grad-op / no-shard 三种模式
  3. 反向预取(Backward Prefetch):反向中预取下一层参数,隐藏通信延迟
  4. 生产级验证:在 Meta 的 512+ GPU 集群上验证

二、技术方法

分片策略

模式分片对象适用场景
FULL_SHARD参数+梯度+优化器大模型,高显存瓶颈
SHARD_GRAD_OP梯度+优化器中等模型,适度通信
NO_SHARD不切片(等同 DDP)小模型,最优化通信

反向预取

在反向传播中,fsdp_pre_backward_hook 提前预取下一层的参数 AllGather,与当前层的计算 overlap。相比 ZeRO-3 的实现,反向预取约减少 20% 的通信等待时间。


三、实验

在 Meta 的生产集群上验证:

  • 174B 参数模型在 256 A100 上训练
  • 相比 DDP 把模型容量从 41B 提升到 174B
  • 与 ZeRO-3 性能相当,零额外代码依赖

四、个人评价

FSDP 的重要意义在于将 ZeRO-3 的显存优化带入 PyTorch 主流生态,向所有 PyTorch 用户开放了大规模训练能力。它在训练小模型时引入的开销比 ZeRO 更可控(通过混合分片策略),但在极致大模型场景下,Megatron-LM 的 TP+PP 仍不可替代。

相关链接