PyTorch Distributed Training
As AI models keep growing, distributed training becomes necessary for training large models (VLMs, LLMs, diffusion models, etc.), or simply when a single GPU does not have enough compute or memory. In this article, I'll give a quick introduction to setting up and running distributed training with the PyTorch framework across multiple GPUs and multiple nodes. I chose PyTorch because it is designed to be both easy to use and performant at scale, and it has firmly established itself as the most widely adopted deep learning framework in the research community.
There are multiple training strategies in PyTorch Distributed:
- Distributed Data-Parallel (DDP): replicates the full model on each GPU, distributes the data, and synchronizes gradients across GPUs. It assumes each GPU has enough VRAM to handle at least a batch size of 1; if your model is too large for a single GPU to process even one batch, consider the three methods below.
- Fully Sharded Data-Parallel Training (FSDP2): Shards model parameters, gradients, and optimizer states across GPUs to save memory and train huge models.
- Tensor Parallel (TP): Splits computations inside layers across GPUs to handle very large layers.
- Pipeline Parallel (PP): Splits model layers into stages across GPUs and processes micro-batches in sequence like an assembly line.
Here, I will cover DDP first, as it is the simplest strategy and the easiest to scale across multiple GPUs and nodes.
Distributed Data-Parallel (DDP)
[Image: Distributed Data-Parallel overview (source in the References below)]
This method clones the full model onto each GPU and distributes the data across the processes. Each process then runs its forward and backward pass independently. Finally, the gradients are synchronized across GPUs with an all-reduce (backed by NCCL).
The NVIDIA Collective Communication Library (NCCL) provides "standard communication routines for GPUs, implementing all-reduce, all-gather, reduce, broadcast, reduce-scatter, as well as any send/receive based communication pattern. It has been optimized to achieve high bandwidth on platforms using PCIe, NVLink, NVswitch, as well as networking using InfiniBand Verbs or TCP/IP sockets. NCCL supports an arbitrary number of GPUs installed in a single node or across multiple nodes, and can be used in either single- or multi-process (e.g., MPI) applications." (https://github.com/NVIDIA/nccl).
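To make the all-reduce step concrete, here is a small illustrative sketch (not part of the training script below) that manually averages gradients across processes with `torch.distributed.all_reduce`; DDP performs essentially this synchronization for you, in buckets, during the backward pass.

```python
import torch
import torch.distributed as dist

def average_gradients(model: torch.nn.Module):
    """Manually average gradients across all processes (what DDP automates)."""
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            # Sum this gradient tensor across every process, then divide to get the mean
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= world_size
```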
Single GPU Training to DDP Training
Before writing a distributed training script, I recommend writing a single-GPU version first; converting it to DDP afterwards is much easier. Here is a simple single-GPU training script:
```python
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

class TrainDataset(Dataset):
    def __init__(self, size):
        self.size = size
        self.data = [(torch.rand(20), torch.rand(1)) for _ in range(size)]

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        return self.data[index]

class Trainer:
    def __init__(
        self,
        model: torch.nn.Module,
        train_data: DataLoader,
        optimizer: torch.optim.Optimizer,
        gpu_id: int,
        save_every: int,
    ) -> None:
        self.gpu_id = gpu_id
        self.model = model.to(gpu_id)
        self.train_data = train_data
        self.optimizer = optimizer
        self.save_every = save_every

    def _run_batch(self, source, targets):
        self.optimizer.zero_grad()
        output = self.model(source)
        loss = F.cross_entropy(output, targets)
        loss.backward()
        self.optimizer.step()

    def _run_epoch(self, epoch):
        b_sz = len(next(iter(self.train_data))[0])
        print(f"[GPU{self.gpu_id}] Epoch {epoch} | Batchsize: {b_sz} | Steps: {len(self.train_data)}")
        for source, targets in self.train_data:
            source = source.to(self.gpu_id)
            targets = targets.to(self.gpu_id)
            self._run_batch(source, targets)

    def _save_checkpoint(self, epoch):
        # I recommend saving the current epoch and the optimizer as well to continue training if needed
        ckp = self.model.state_dict()
        PATH = "checkpoint.pt"
        torch.save(ckp, PATH)
        print(f"Epoch {epoch} | Training checkpoint saved at {PATH}")

    def train(self, max_epochs: int):
        for epoch in range(max_epochs):
            self._run_epoch(epoch)
            if epoch % self.save_every == 0:
                self._save_checkpoint(epoch)

def main(device, total_epochs, save_every, batch_size):
    dataset = TrainDataset(2048)
    model = torch.nn.Linear(20, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    train_data = DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=True
    )
    trainer = Trainer(model, train_data, optimizer, device, save_every)
    trainer.train(total_epochs)

if __name__ == "__main__":
    total_epochs = 50
    save_every = 5
    batch_size = 32
    device = 0
    main(device, total_epochs, save_every, batch_size)
```
For distributed training in PyTorch, there are some terms you need to remember:
- `world_size`: the total number of processes in the training job.
- `rank`: the unique ID of each process.
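As a quick illustration (my own snippet, not from the training script), once the process group has been initialized, as shown in the next sections, each process can query these values at runtime:

```python
import torch.distributed as dist

# Only valid after init_process_group() has been called in this process
rank = dist.get_rank()              # ID of this process: 0, 1, ..., world_size - 1
world_size = dist.get_world_size()  # total number of processes in the job
print(f"Hello from process {rank} of {world_size}")
```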
To convert this to DDP training, a few additional imports are needed:
```python
import os  # needed for the MASTER_ADDR / MASTER_PORT environment variables
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
```
First, each process needs to join a process group so that all processes can communicate with each other:
```python
def ddp_setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"  # Usually "localhost" for standalone training
    os.environ["MASTER_PORT"] = "12355"      # Any free port
    torch.cuda.set_device(rank)
    init_process_group(backend="nccl", rank=rank, world_size=world_size)
```
Here, `rank` is the current process ID and `world_size` is the total number of processes (typically one process per GPU). We use `nccl` as the backend.
Also, in a distributed training program there is always one master process plus worker processes. The master process coordinates the communication across all processes, which is why we need to specify the master host (IP address) and a free port.
After setting up the process group, we need to:
- Initialize the model with `DDP(model, device_ids=[...])`. NOTE that when saving the model, save `model.module.state_dict()` instead of just `model.state_dict()`:

  ```python
  # Wrap the model with DDP
  self.model = DDP(self.model, device_ids=[self.gpu_id])

  # When saving the model, use model.module
  ckp = self.model.module.state_dict()
  torch.save(ckp, PATH)
  ```
- Add `DistributedSampler(dataset)` to the `DataLoader` initialization and set `shuffle=False`, because the `DistributedSampler` already handles shuffling (more on per-epoch shuffling in the note after this list):

  ```python
  DataLoader(
      dataset,
      batch_size=batch_size,
      pin_memory=True,
      shuffle=False,
      sampler=DistributedSampler(dataset)
  )
  ```
- In the `main` function, call `ddp_setup` first, then call `destroy_process_group` to clean up after training is done. Also, remember to replace `device` with `rank`, since each replica of the model now runs on a different GPU:

  ```python
  def main(rank, world_size, save_every, total_epochs, batch_size):
      ddp_setup(rank, world_size)
      ...
      trainer = Trainer(model, train_data, optimizer, rank, save_every)
      trainer.train(total_epochs)
      destroy_process_group()
  ```
- Use `mp.spawn(main, args=(...), nprocs=world_size)` to launch the distributed run. Here `world_size` is the number of processes, and usually each process runs on one GPU:

  ```python
  if __name__ == "__main__":
      total_epochs = 50
      save_every = 5
      batch_size = 32
      # Typically each GPU runs one process, so world_size = number of GPUs
      world_size = torch.cuda.device_count()

      # No need to pass rank; mp.spawn passes it to main as the first argument
      mp.spawn(main, args=(world_size, save_every, total_epochs, batch_size), nprocs=world_size)
  ```
The full file can be downloaded here: multigpu.py
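Two practical details that the steps above gloss over: `DistributedSampler` only reshuffles the data if you call `set_epoch` at the start of every epoch, and with several processes it is cleaner to let just one of them (usually rank 0) write the checkpoint so the replicas don't race on the same file. A minimal sketch of both, reusing the `Trainer` fields from the script above:

```python
def _run_epoch(self, epoch):
    # Reshuffle the data differently at every epoch; without this,
    # DistributedSampler yields the same ordering each epoch.
    self.train_data.sampler.set_epoch(epoch)
    for source, targets in self.train_data:
        source = source.to(self.gpu_id)
        targets = targets.to(self.gpu_id)
        self._run_batch(source, targets)

def _save_checkpoint(self, epoch):
    # All replicas hold identical weights, so only rank 0 needs to write the file.
    if self.gpu_id == 0:
        ckp = self.model.module.state_dict()
        torch.save(ckp, "checkpoint.pt")
        print(f"Epoch {epoch} | Training checkpoint saved")
```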
Torchrun
Instead of manually initializing `rank` and `world_size` for each process, we can use `torchrun`, PyTorch's built-in launcher for distributed training:
```python
def ddp_setup():
    # The master address and port are no longer set here; they are defined in the launch command
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    # Just define the backend; torchrun automatically assigns rank and world_size
    init_process_group(backend="nccl")
```
The `rank` and `world_size` arguments no longer need to be specified, so we remove them:
```python
def main(save_every, total_epochs, batch_size):
    ddp_setup()
    local_rank = int(os.getenv("LOCAL_RANK"))  # torchrun exposes the local rank as an environment variable
    ...
    trainer = Trainer(model, train_data, optimizer, local_rank, save_every)
    trainer.train(total_epochs)
    destroy_process_group()
```
We use `local_rank` as the GPU ID, since each process runs on its own GPU. Then, instead of using `mp.spawn`, we launch the script with `torchrun`:
```python
if __name__ == "__main__":
    total_epochs = 50
    save_every = 5
    batch_size = 32
    main(save_every, total_epochs, batch_size)
```
Run script:
```bash
# --standalone: standalone mode, run multiple processes on a single machine
# --nproc_per_node: number of processes per node, usually the number of GPUs used
torchrun \
    --standalone \
    --nproc_per_node=<number_of_processes_per_node> \
    multigpu_torchrun.py
```
The full file can be downloaded here: multigpu_torchrun.py
Multi-node Training
All the setups above still run on a single machine with multiple GPUs. Suppose you have multiple machines (nodes), each with multiple GPUs, and you want to fully exploit them. `torchrun` supports multi-node training as well, and there are several ways to launch it on multiple machines:
- Clone the training code to every node and run the training script on each node manually.
- Use a workload manager like SLURM.
There are a few things you should verify before running multi-node training:
- Make sure your nodes can reach each other over TCP, e.g. with `nc -vz <IP> <port>`.
- Use an identical training environment on every node.
Then we're ready to start. Here I launch the jobs manually on the 2 machines I have set up. On the first node, I run:
```bash
torchrun \
    --nproc_per_node=1 \
    --nnodes=2 \
    --node_rank=0 \
    --rdzv_id=123 \
    --rdzv_backend=c10d \
    --rdzv_endpoint=127.0.0.1:34359 \
    multigpu_torchrun.py
```
- `nproc_per_node`: number of processes per node; each of my machines only has 1 GPU :(
- `nnodes`: number of nodes.
- `node_rank`: ID of the current node.
- `rdzv_id`: ID of the rendezvous, shared by all nodes. The rendezvous is the process by which all participating processes find each other and establish communication.
- `rdzv_backend`: the backend used for the rendezvous, typically `c10d`.
- `rdzv_endpoint`: the endpoint of the rendezvous. Make sure this endpoint is reachable from every node. I set the IP to `127.0.0.1` here because the rendezvous runs on this node.
On the other node, I run:
```bash
torchrun \
    --nproc_per_node=1 \
    --nnodes=2 \
    --node_rank=1 \
    --rdzv_id=123 \
    --rdzv_backend=c10d \
    --rdzv_endpoint=192.168.0.60:34359 \
    multigpu_torchrun.py
```
- `rdzv_endpoint`: the endpoint of the rendezvous. This node does not host the rendezvous, so the endpoint passed here must point to the node that does and be reachable from here. Since my two machines are on the same local network, I can simply use the first node's IP address.
Also, in multi-node training, `int(os.getenv("LOCAL_RANK"))` is the local rank of a process within its node, while `int(os.getenv("RANK"))` is the global rank of that process across all nodes.
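As a small sketch (the variable names are mine, not from the script), this is how the different ranks are typically read and used inside the training script when launched with torchrun:

```python
import os
import torch

local_rank = int(os.environ["LOCAL_RANK"])   # GPU index on this node (0 on both of my machines)
global_rank = int(os.environ["RANK"])        # unique ID across all nodes (0 and 1 in this setup)
world_size = int(os.environ["WORLD_SIZE"])   # total number of processes across all nodes

torch.cuda.set_device(local_rank)            # bind this process to its local GPU
if global_rank == 0:
    # Log (or save checkpoints) from a single process across the whole job
    print(f"Running with {world_size} processes in total")
```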
Multi-node Debugging
The training script in a multi-node setting is almost the same as in the single-node setting. However, debugging it can be quite painful, because most issues come from the environment and the network. Some suggestions:
- Make sure every node uses the same training environment; I used the PyTorch Docker image `pytorch/pytorch:2.4.0-cuda12.4-cudnn9-devel`.
- Make sure your nodes can reach each other over TCP: `nc -vz <IP> <port>`.
- Check for firewalls and disable IPv6 if necessary.
- Pin the NCCL socket interface with `export NCCL_SOCKET_IFNAME=<if-name>`; find the interface name with `ifconfig` (look for the Ethernet interface).
- If the `c10d` rendezvous keeps failing, try the legacy rendezvous backend `etcd`.
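When NCCL itself misbehaves, turning on its logging usually reveals which interface or transport it is trying to use. A minimal sketch using real NCCL environment variables; the interface name `eth0` is only a placeholder, and you can equivalently export these in the shell before running torchrun:

```python
import os

# Enable verbose NCCL logging to see which interfaces and transports NCCL picks and where it fails.
os.environ["NCCL_DEBUG"] = "INFO"

# Pin NCCL to a specific network interface ("eth0" is a placeholder; check `ifconfig`).
os.environ["NCCL_SOCKET_IFNAME"] = "eth0"

# Both must be set before init_process_group(backend="nccl") is called.
```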
Summary
Remember to:
- To create a training script with DDP, I recommend writing a single-GPU version first, then converting it to distributed training.
- Import the necessary PyTorch distributed packages.
- Initialize the process group with `init_process_group`; I recommend using `torchrun` so you don't have to specify the rank and world size yourself.
- Wrap the training model with `DDP`.
- Add a `DistributedSampler` to the `DataLoader` to distribute the data across GPUs.
- Run the script with `torchrun`.
Besides the PyTorch distributed packages, there are other frameworks that support distributed training with only a little extra configuration, such as:
- PyTorch Lightning
- Hugging Face Accelerate
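As an example of how much boilerplate these frameworks hide, here is a rough sketch of the same training loop with Hugging Face Accelerate; it assumes the `model`, `optimizer`, and `train_data` objects from the earlier script and is meant as an illustration rather than a drop-in replacement:

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()  # picks up the distributed configuration from `accelerate launch` or torchrun

# Wraps the model in DDP, shards the DataLoader across processes, and moves everything to the right device.
model, optimizer, train_data = accelerator.prepare(model, optimizer, train_data)

for source, targets in train_data:
    optimizer.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(source), targets)
    accelerator.backward(loss)  # replaces loss.backward() so gradient synchronization still happens
    optimizer.step()
```

The same code then runs on a single GPU, multiple GPUs, or multiple nodes depending on how it is launched (e.g. with `accelerate launch` or torchrun).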
Once you understand how distributed training works, I recommend using the frameworks above, as they wrap almost all of the distributed training strategies (DDP, single-node, multi-node, FSDP, etc.). If you need more information or want to discuss further, feel free to contact me.
References
- Image source: https://www.youtube.com/watch?v=bwNtfxEDjGA
- Suraj Subramanian, Distributed Data Parallel in PyTorch, https://www.youtube.com/watch?v=-K3bZYHYHEA&list=PL_lsbAsL_o2CSuhUhJIiW0IkdT5C2wGWj