Setting Up a PyTorch Project
The basic workflow for setting up a project in PyTorch.
Basic workflow
- Survey related work: evaluation metrics, datasets, classic solutions, how the problem at hand differs from existing approaches, rough estimates of accuracy and speed, and the main difficulties.
- Explore the data and settle on an approach.
- Write, in turn, the model (models.py), the dataset loading interface (datasets.py), the loss functions (losses.py), and the evaluation metrics (criterion.py).
- Write the training script (train.py) and the test script (test.py).
- Train, debug, and evaluate.
- Deploy the model.
Note: do not put all layers and models in a single file. Best practice is to separate the final network into its own file (networks.py) and keep layers, losses, and ops in their respective files (layers.py, losses.py, ops.py). The finished model (composed of one or more networks) should be referenced in a single file named after it, e.g. yolov3.py or dcgan.py, as sketched below.
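As a concrete illustration, one such project layout might look like this (only the file names mentioned above are from the text; the tree itself is an assumption):

```
project/
├── layers.py      # reusable custom layers
├── losses.py      # custom loss functions
├── ops.py         # low-level operations
├── networks.py    # networks assembled from layers
├── yolov3.py      # the finished model, referencing one or more networks
├── datasets.py    # dataset loading interfaces
├── criterion.py   # evaluation metrics
├── train.py       # training script
└── test.py        # test script
```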
(1) Building the neural network
A custom network generally inherits from the nn.Module class and must implement a forward method that defines the forward pass through its layers and operations. For a simple network with a single input and a single output, use the following pattern:
```python
import torch.nn as nn

class ConvBlock(nn.Module):
    def __init__(self):
        super(ConvBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(...),
            nn.ReLU(),
            nn.BatchNorm2d(...)
        )

    def forward(self, x):
        return self.block(x)


class SimpleNetwork(nn.Module):
    def __init__(self, num_of_layers=15):
        super(SimpleNetwork, self).__init__()
        layers = list()
        for i in range(num_of_layers):
            layers.append(...)
        self.conv0 = nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv0(x)
        return out
```
We recommend splitting the network into smaller reusable pieces. A network is composed of operations and/or other network modules. Loss functions are nn.Module subclasses as well, so they can be integrated directly into the network, as the sketch below illustrates.
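For example, a minimal sketch (the TwoBlockNet name and all channel sizes are made up for illustration) that composes small blocks into a larger network and attaches a loss as a submodule:

```python
import torch
import torch.nn as nn

class TwoBlockNet(nn.Module):
    """Hypothetical network built from smaller reusable modules."""
    def __init__(self):
        super(TwoBlockNet, self).__init__()
        # Networks are composed of ops or other network modules.
        self.block1 = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU())
        self.block2 = nn.Sequential(nn.Conv2d(16, 3, 3, padding=1), nn.ReLU())
        # A loss is itself an nn.Module, so it can live inside the network.
        self.loss_fn = nn.MSELoss()

    def forward(self, x, target=None):
        out = self.block2(self.block1(x))
        if target is not None:
            return out, self.loss_fn(out, target)
        return out
```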
(2) Custom datasets
```python
import torch
from PIL import Image
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    """Custom dataset returning (data, target) image pairs."""

    def __init__(self, root_dir='./data', transform=None):
        """
        Args:
            root_dir: directory holding the data.
            transform: optional transform applied jointly to data and target.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.train_data = ...
        self.train_target = ...

    def __len__(self):
        return len(self.train_data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        data = Image.open(self.train_data[idx])
        target = Image.open(self.train_target[idx])
        if self.transform:
            data, target = self.transform(data, target)
        # Key names must match what the training loop reads from each batch.
        sample = {'data': data, 'target': target}
        return sample
```
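Note that transform here is called with both data and target, unlike standard torchvision transforms, which take a single image. A minimal sketch of such a paired transform (the PairedToTensor name is made up):

```python
from torchvision.transforms import functional as TF

class PairedToTensor:
    """Hypothetical paired transform: converts both images to tensors."""
    def __call__(self, data, target):
        return TF.to_tensor(data), TF.to_tensor(target)
```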
(3) Custom losses
PyTorch already ships with many standard loss functions, but sometimes you need to create your own. To do so, create a separate file losses.py and extend the nn.Module class to build a custom loss:
```python
import torch
import torch.nn as nn

class CustomLoss(nn.Module):
    def __init__(self):
        """CustomLoss: mean squared error, written by hand."""
        super(CustomLoss, self).__init__()

    def forward(self, x, y):
        return torch.mean(torch.square(x - y))
```
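Usage is identical to the built-in losses; for instance, with dummy tensors:

```python
criterion = CustomLoss()
pred = torch.randn(4, 3, 8, 8, requires_grad=True)  # dummy predictions
target = torch.randn(4, 3, 8, 8)                    # dummy targets
loss = criterion(pred, target)  # same value as nn.MSELoss()(pred, target)
loss.backward()
```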
(4) A recommended code structure for a training script
```python
# import statements
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils import data
...  # e.g. parse command-line arguments into `args`

# set flags / seeds
torch.backends.cudnn.benchmark = True
np.random.seed(1)
torch.manual_seed(1)
torch.cuda.manual_seed(1)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
...
# dataset
transform_train = ...
transform_test = ...
train_dataset = CustomDataset(args.train_dataset, is_trainval=True, transform=transform_train)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
                                           shuffle=True, num_workers=0, drop_last=False)
valid_dataset = CustomDataset(args.valid_dataset, is_trainval=True, transform=transform_test)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.val_batch_size,
                                           shuffle=False, num_workers=0)  # no need to shuffle validation data
# model & loss
model = CustomNet().to(device)
criterion = ...

# lr & optimizer
optimizer = optim.SGD(model.parameters(), lr=args.init_lr,
                      momentum=args.momentum, weight_decay=args.weight_decay)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 70], gamma=0.1)
# resume from a checkpoint
if args.resume:
    if os.path.isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        best_prec = checkpoint['best_prec']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {}) Prec: {:f}"
              .format(args.resume, checkpoint['epoch'], best_prec))
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))
def train(epoch):
    model.train()  # call model.train() (or model.eval() at test time) before model(x)
    avg_loss = 0.0
    for batch_idx, batchdata in enumerate(train_loader):
        data, target = batchdata["data"], batchdata["target"]
        data, target = data.to(device), target.to(device)
        # clear accumulated gradients with optimizer.zero_grad() before loss.backward()
        optimizer.zero_grad()  # optimizer.zero_grad() has the same effect as model.zero_grad() here
        predict = model(data)
        loss = criterion(predict, target)
        avg_loss += loss.item()
        loss.backward()
        optimizer.step()
        print('Train Epoch: {} [{}/{} ({:.1f}%)]\tLoss: {:.6f}'.format(
            epoch, batch_idx * len(data), len(train_loader.dataset),
            100. * batch_idx / len(train_loader), loss.item()))
    if (epoch + 1) % args.save_interval == 0:
        state = {'epoch': epoch + 1,
                 'state_dict': model.state_dict(),
                 'best_prec': 0.0,
                 'optimizer': optimizer.state_dict()}
        model_path = os.path.join(args.checkpoint_dir, 'model_' + str(epoch) + '.pth')
        torch.save(state, model_path)
def test():
    model.eval()
    test_loss = 0.0
    psnr = 0.0
    with torch.no_grad():  # no gradients needed during evaluation
        for batch_idx, batchdata in enumerate(valid_loader):
            data, target = batchdata["data"], batchdata["target"]
            data, target = data.to(device), target.to(device)
            predict = model(data)
            test_loss += criterion(predict, target).item()
            psnr += criterion(predict * 255, target * 255).item()  # accumulate the metric per batch
    test_loss /= len(valid_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, PSNR: ({:.1f})\n'.format(
        test_loss, psnr / len(valid_loader.dataset)))
    return psnr / float(len(valid_loader.dataset))
best_prec = 0.0
for epoch in range(args.start_epoch, args.epochs):
    train(epoch)
    scheduler.step()
    print(optimizer.state_dict()['param_groups'][0]['lr'])  # current learning rate
    current_prec = test()
    is_best = current_prec > best_prec  # flip the comparison if a lower metric is better
    best_prec = max(best_prec, current_prec)  # use max here, or min if lower is better
    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'best_prec': best_prec,
        'optimizer': optimizer.state_dict(),
    }, is_best, args.checkpoint_dir)
```
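The script above calls a save_checkpoint helper that it does not define; a minimal sketch of one possible implementation (the file names checkpoint.pth and best_model.pth are assumptions):

```python
import os
import shutil
import torch

def save_checkpoint(state, is_best, checkpoint_dir):
    """Save the latest checkpoint; copy it to best_model.pth if it is the best so far."""
    path = os.path.join(checkpoint_dir, 'checkpoint.pth')
    torch.save(state, path)
    if is_best:
        shutil.copyfile(path, os.path.join(checkpoint_dir, 'best_model.pth'))
```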