AI智能摘要
GPT
这里是萌新AI,这篇文章介绍了 PyTorch 框架中的 FSDP 分布式训练方法。FSDP 通过将模型参数、优化器参数和梯度进行分片处理,有效降低了显存占用。其核心在于将模型划分为多个 FSDP unit,每个 GPU 仅存储对应 unit 的参数,并在前向和反向传播时通过 all gather 操作按需加载,从而避免了完整模型的显存负担。
URL
type
status
date
slug
summary
tags
category
icon
password
此篇博客主要用于FSDP学习记录。因为Pytorch已经实现了该API,所以该篇博客仅仅使用Pytorch框架进行。(注意:Pytorch的FSDP思想借鉴了DeepSpeed的ZeRO Stage3思想。DeepSpeed的内容不在该博客涉及,如果想了解DeepSpeed,请查阅DeepSpeed相关博客!)
📝 FSDP理论
FSDP
FSDP全称FullyShardedDataParallel,即完全分段数据并行。FSDP主要是将模型参数、优化器参数和梯度进行了分片处理。每个GPU依然复制一个模型框架,但仅仅保留相对应层的部分参数,清空其他参数的缓存。
在前向和反向传播时,正常的分布式方法(以DDP为例,不考虑all reduce操作)一次性传递,没有停顿。但是,FSDP的每个GPU仅仅存储了部分参数,在前向和反向传播时,FSDP需要all gather操作。如果all gather操作将所有GPU的参数一次性gather,那和将完整模型加载到GPU没有区别,并没有起到降低显存占用的效果。那么,如何解决这个问题?那就是将模型分块,即FSDP unit。执行某个FSDP unit时,便加载该unit的参数,不加载其他unit的参数。这样模型在前向和反向传播时,模型不会加载完整的参数。
FSDP unit
FSDP unit包含模型的几个layer,这几个layer可以是连续的,也可以是不连续的。如下图,将该模型分成三个FSDP unit。FSDP unit0包含0和3层,FSDP unit1包含1和2层,FSDP unit2包含4和5层。

FSDP unit仅仅将模型分为不同的块,那么接下来进行模型参数、优化器参数和梯度进行了分片处理,即sharding操作。
sharding
sharding将一个FSDP unit中的模型参数展平成一维后,分别储存在不同的GPU中。假设我们有2个GPU,模型weight的shape=(2,2)。weight展平后的shape=(4,)。那么,将展平后的weight维度:4除以节点(node)的GPU个数,即每个GPU的shape=(2,)。可以通过下图进行辅助理解:

那么,在前向传播时,FSDP会执行all gather操作,即先gather后broadcast。如下图:

前向传播结束后,FSDP unit只会保留原有的参数,通过gather获取的参数会被释放,恢复到原来的状态。
每个GPU输入的数据是不同的,这也导致每个GPU计算出的梯度也是不同的。如果每个GPU各自更新,那么这便会造成整个模型梯度混乱。为什么会发生梯度混乱呢?因为每个GPU只负责一部分参数,则更新时需要将FSDP unit的参数先gather接着更新然后broadcast。这样会造成GPU更新混乱问题,当某个GPU更新后,下一个GPU是在上一个GPU更新后接着更新。
针对上诉问题,FSDP采取了与DDP相同的处理,进行了reduce scatter操作。首先将各个GPU的梯度进行reduce(即求均值),然后将梯度拆分(scatter)并划分到对应的GPU,最后参数更新。可以通过下图辅助理解:

反向传播与前向传播操作相同,也需要all gather进行。以上便是FSDP的前向传播、反向传播和梯度更新的主要理论知识
FSDP并行
其实,FSDP的并行操作不仅仅包含数据并行还包括GPU间的通信与模型前向传播、反向传播和梯度计算的并行。可以简单理解为:FSDP unit0进行前向传播,FSDP unit1进行all gather操作等等。这便是FSDP并行的思想。
从FSDP并行可以发现,FSDP比较适合大模型训练,即模型参数需要多张卡才能完整存放。当在小模型中使用FSDP,会得不偿失。因为GPU间的通信相较于前向传播、后向传播以及梯度计算较为费时,所以小模型并不建议采用FSDP进行训练。
🤗 FSDP代码实践
📎 参考文章
以上便是FSDP方法学习记录,欢迎您在底部评论区留言,一起交流~
- 作者:不爱吃香菜的萌新
- 链接:https://hexo.levsongsw.com//deeplearn/pytorchDT3
- 声明:本文采用 CC BY-NC-SA 4.0 许可协议,转载请注明出处。


![[pytorch distributed] 从 DDP、模型并行、流水线并行到 FSDP(NCCL,deepspeed 与 Accelerate)_哔哩哔哩_bilibili](https://www.notion.so/image/https%3A%2F%2Fi2.hdslb.com%2Fbfs%2Farchive%2F014150ff43de24d8cb35084d9a11356e31733b4a.jpg%40100w_100h_1c.png%4057w_57h_1c.png?table=block&id=24602083-7c21-809d-9444-cf1e4e51c8cb&t=24602083-7c21-809d-9444-cf1e4e51c8cb)
![[pytorch distributed] 从 DDP、模型并行、流水线并行到 FSDP(NCCL,deepspeed 与 Accelerate)_哔哩哔哩_bilibili](https://www.notion.so/image/https%3A%2F%2Fi2.hdslb.com%2Fbfs%2Farchive%2F014150ff43de24d8cb35084d9a11356e31733b4a.jpg%40100w_100h_1c.png?table=block&id=24602083-7c21-809d-9444-cf1e4e51c8cb&t=24602083-7c21-809d-9444-cf1e4e51c8cb)




