AI智能摘要
GPT
这里是萌新AI,这篇文章介绍了 Pytorch 框架中的 FSDP(完全分段数据并行)原理与实现。核心思想是将模型参数、优化器参数和梯度分片存储到不同 GPU 上,每个 GPU 仅保留部分参数。为解决显存问题,引入 FSDP unit 将模型分块,执行时只加载当前块参数。通过 sharding 操作,将 unit 内参数展平后分配到各 GPU,显著降低显存占用。文章还提供了相关代码实践与参考资源。
URL
type
Post
status
Published
date
Jul 26, 2025
slug
pytorchDT3
summary
分布式训练
tags
深度学习
Pytorch
大模型
实用教程
category
深度学习
icon
password
此篇博客主要用于FSDP学习记录。因为Pytorch已经实现了该API,所以该篇博客仅仅使用Pytorch框架进行。(注意:Pytorch的FSDP思想借鉴了DeepSpeed的ZeRO Stage3思想。DeepSpeed的内容不在该博客涉及,如果想了解DeepSpeed,请查阅DeepSpeed相关博客!)
📝 FSDP理论
FSDP
FSDP全称FullyShardedDataParallel,即完全分段数据并行。FSDP主要是将模型参数、优化器参数和梯度进行了分片处理。每个GPU依然复制一个模型框架,但仅仅保留相对应层的部分参数,清空其他参数的缓存。
在前向和反向传播时,正常的分布式方法(以DDP为例,不考虑all reduce操作)一次性传递,没有停顿。但是,FSDP的每个GPU仅仅存储了部分参数,在前向和反向传播时,FSDP需要all gather操作。如果all gather操作将所有GPU的参数一次性gather,那和将完整模型加载到GPU没有区别,并没有起到降低显存占用的效果。那么,如何解决这个问题?那就是将模型分块,即FSDP unit。执行某个FSDP unit时,便加载该unit的参数,不加载其他unit的参数。这样模型在前向和反向传播时,模型不会加载完整的参数。
FSDP unit
FSDP unit包含模型的几个layer,这几个layer可以是连续的,也可以是不连续的。如下图,将该模型分成三个FSDP unit。FSDP unit0包含0和3层,FSDP unit1包含1和2层,FSDP unit2包含4和5层。

FSDP unit仅仅将模型分为不同的块,那么接下来进行模型参数、优化器参数和梯度进行了分片处理,即sharding操作。
sharding
sharding将一个FSDP unit中的模型参数展平成一维后,分别储存在不同的GPU中。假设我们有2个GPU,模型weight的shape=(2,2)。weight展平后的shape=(4,)。那么,将展平后的weight维度:4除以节点(node)的GPU个数,即每个GPU的shape=(2,)。可以通过下图进行辅助理解:

那么,在前向传播时,FSDP会执行all gather操作,即先gather后broadcast。如下图:

前向传播结束后,FSDP unit只会保留原有的参数,通过gather获取的参数会被释放,恢复到原来的状态。
每个GPU输入的数据是不同的,这也导致每个GPU计算出的梯度也是不同的。如果每个GPU各自更新,那么这便会造成整个模型梯度混乱。为什么会发生梯度混乱呢?因为每个GPU只负责一部分参数,则更新时需要将FSDP unit的参数先gather接着更新然后broadcast。这样会造成GPU更新混乱问题,当某个GPU更新后,下一个GPU是在上一个GPU更新后接着更新。
针对上诉问题,FSDP采取了与DDP相同的处理,进行了reduce scatter操作。首先将各个GPU的梯度进行reduce(即求均值),然后将梯度拆分(scatter)并划分到对应的GPU,最后参数更新。可以通过下图辅助理解:

反向传播与前向传播操作相同,也需要all gather进行。以上便是FSDP的前向传播、反向传播和梯度更新的主要理论知识
FSDP并行
其实,FSDP的并行操作不仅仅包含数据并行还包括GPU间的通信与模型前向传播、反向传播和梯度计算的并行。可以简单理解为:FSDP unit0进行前向传播,FSDP unit1进行all gather操作等等。这便是FSDP并行的思想。
从FSDP并行可以发现,FSDP比较适合大模型训练,即模型参数需要多张卡才能完整存放。当在小模型中使用FSDP,会得不偿失。因为GPU间的通信相较于前向传播、后向传播以及梯度计算较为费时,所以小模型并不建议采用FSDP进行训练。
🤗 FSDP代码实践
📎 参考文章
以上便是FSDP方法学习记录,欢迎您在底部评论区留言,一起交流~
- 作者:不爱吃香菜的萌新
- 链接:https://hexo.levsongsw.com//deeplearn/pytorchDT3
- 声明:本文采用 CC BY-NC-SA 4.0 许可协议,转载请注明出处。


![[pytorch distributed] 从 DDP、模型并行、流水线并行到 FSDP(NCCL,deepspeed 与 Accelerate)_哔哩哔哩_bilibili](https://www.notion.so/image/https%3A%2F%2Fi2.hdslb.com%2Fbfs%2Farchive%2F014150ff43de24d8cb35084d9a11356e31733b4a.jpg%40100w_100h_1c.png%4057w_57h_1c.png?table=block&id=24602083-7c21-809d-9444-cf1e4e51c8cb&t=24602083-7c21-809d-9444-cf1e4e51c8cb)
![[pytorch distributed] 从 DDP、模型并行、流水线并行到 FSDP(NCCL,deepspeed 与 Accelerate)_哔哩哔哩_bilibili](https://www.notion.so/image/https%3A%2F%2Fi2.hdslb.com%2Fbfs%2Farchive%2F014150ff43de24d8cb35084d9a11356e31733b4a.jpg%40100w_100h_1c.png?table=block&id=24602083-7c21-809d-9444-cf1e4e51c8cb&t=24602083-7c21-809d-9444-cf1e4e51c8cb)




