前言
用好的風格寫程式,不僅讓別人的生活更加美麗,也是對未來的自己仁慈 XD
今天要分享一下 pytorch-styleguide 這個 github 上的 repository,並記錄我自己的心得,希望可以吸收消化裡面值得參考的部分。
命名
保持命名的 consistency 真的滿重要的,雖然很基本不過還是會有一些管理不佳的程式碼會存在 naming convention 不同的問題。即便不影響功能,但對於管理跟未來閱讀程式碼都是增加不必要的認知負擔,所以有意識地提醒自己做好這一塊還是滿重要的。
檔案架構
一般來說,當在實作一些比較表現較佳或是較進階的 model,通常都會需要實作自己的一些 neural network、custom loss,甚至會需要實作額外的 C++ extension function(例如寫出更有效率的 GPU code),所以怎麼保持 project 架構的乾淨就很重要。
在檔案分配上,我們會希望盡量模組化、讓各檔案保持單純,所以常見的做法是有幾個基本的檔案:
- xxx_networks.py:實作某個 neural network
- layers.py:實作一些基本的 neural network block,以供 xxx_networks.py 使用
- losses.py:實作 loss functions
- ops.py:實作自己客製化的 operations
- dataset.py:實作自己的 dataset class(我是習慣把建立 dataloader 的 function 也放在這裡面)
最後可能會再用一個 train.py 去使用 xxx_networks.py 跟 losses.py 的 class 來寫出 training loop。
接下來舉點例子:
- layers.py
import torch.nn as nn
class ConvBlock(nn.Module):
def __init__(self):
super(ConvBlock, self).__init__()
self.block = nn.Sequential(
nn.Conv2d(...),
nn.ReLU(),
nn.BatchNorm2d(...)
)
def forward(self, x):
return self.block(x)
class ResnetBlock(nn.Module):
def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
super(ResnetBlock, self).__init__()
self.conv_block = self.build_conv_block(...)
def build_conv_block(self, ...):
conv_block = []
conv_block += [nn.Conv2d(...),
norm_layer(...),
nn.ReLU()]
if use_dropout:
conv_block += [nn.Dropout(...)]
conv_block += [nn.Conv2d(...),
norm_layer(...)]
return nn.Sequential(*conv_block)
def forward(self, x):
out = x + self.conv_block(x)
return out
- simple_network.py
import torch.nn as nn
from layers import (
ConvBlock,
ResnetBlock
)
class SimpleNetwork(nn.Module):
def __init__(self, num_resnet_blocks=6):
super(SimpleNetwork, self).__init__()
layers = [ConvBlock(...)]
for i in range(num_resnet_blocks):
layers += [ResBlock(...)]
self.net = nn.Sequential(*layers)
def forward(self, x):
return self.net(x)
- losses.py
import torch.nn as nn
class CustomLoss(nn.Module):
def __init__(self):
super(CustomLoss, self).__init__()
def forward(self,x,y):
loss = torch.mean((x - y)**2)
return loss
- train.py
# import statements
import torch
import torch.nn as nn
from itertools import islice
from losses import CustomLoss
from simple_network import SimpleNetwork
from torch.nn.parallel import DistributedDataParallel as DDP
if __name__ == '__main__':
# Parse arguments
parser = argparse.ArgumentParser()
opt = parser.parse_args()
...
# Setup training data
train_dataset = ...
train_data_loader = data.DataLoader(train_dataset, ...)
test_dataset = ...
test_data_loader = data.DataLoader(test_dataset ...)
...
# Instantiate network
net = SimpleNetwork(...)
# Create losses (criterion in pytorch)
criterion = CustomLoss()
# If running on GPU(可以增進 model 訓練速度)
use_cuda = torch.cuda.is_available()
if use_cuda:
net = net.cuda()
optim = torch.optim.Adam(net.parameters(), lr=opt.lr)
# load checkpoint if needed/ wanted
start_n_iter = 0
start_epoch = 0
if opt.resume:
ckpt = load_checkpoint(opt.path_to_checkpoint) # custom method for loading last checkpoint
net.load_state_dict(ckpt['net'])
start_epoch = ckpt['epoch']
start_n_iter = ckpt['n_iter']
optim.load_state_dict(ckpt['optim'])
print("last checkpoint restored")
...
# Run on multiple node/GPU (PyTorch 的 Distributed Data Parallel 是滿好用的工具,可以做到 multi-node、multi-GPU 的 training,在業界上應該是很常需要使用,畢竟這可以大幅縮短訓練時間,代價是要用到更多機器)
net = DDP(net)
...
# Start the main loop
n_iter = start_n_iter
for epoch in range(start_epoch, opt.epochs):
# Set models to train mode
net.train()
for data in islice(train_data_loader, files_in_an_epoch):
img, label = data
if use_cuda:
img = img.cuda()
label = label.cuda()
...
# Forward pass
output = net(img)
# Calculate loss
loss = criterion(output, label)
# Backward pass
optim.zero_grad()
loss.backward()
optim.step()
# Do a test pass every x epochs and save checkpoint
if epoch % x == x-1:
net.eval()
# 用 torch.no_grad 可以省下不少 memory usage
with torch.no_grad:
# Do tests using data from test_dataloader
# Save checkpoint
...
一些注意事項
- 假設你今天使用 multi-node training,PyTorch 的 DDP 是 synchronize gradients,所以如果你的 model 裡面有 batch norm layer,那要注意 running mean/variance 在各個 node 上面可能會不同(因為通常我們會做 dataset sharding,讓各個 node 看到的),這可能會使得 training 時有些許的 instability。
- 為了保持 code 的乾淨簡單,上面的 train.py 範例我移掉了 tqdm、tensorboard 等東西,但實際應用時,這些都滿方便的。
總結
今天簡單分享了一些撰寫 PyTorch 程式碼的 tips,這份 guide 裡面還有很多小 tips 我沒有一一寫進來,但其實也沒必要寫,因為各 project 不同的細節還有太多。當你自己投入去做一個頗具規模的 PyTorch project 時,你會發現 PyTorch Forum 跟 PyTorch github repo 裡面的不少討論串是你的好夥伴。
另外網路上有不少的 open source project 也都有類似的架構,有興趣的讀者不妨去看一些自己感興趣的 PyTorch project 學習高手們的寫法。