PyTorch 项目搭建

PyTorch 项目搭建的基本流程

基本工作流程

  1. 相关工作调研:评价指标、数据集、经典解决方案、待解决问题与已有方案的差异、精度和速度预估、相关难点。

  2. 数据探索和方案确定

  3. 依次编写模型 models.py、数据集读取接口 datasets.py、损失函数 losses.py、评价指标 criterion.py
  4. 编写训练脚本(train.py)和测试脚本(test.py)
  5. 训练、调试和测评
  6. 模型的部署

注意,不要把所有层和模型都放在同一个文件中。最佳做法是将最终的网络单独放在一个文件(networks.py)中,并将 layers、loss 和 ops 分别保存在各自的文件(layers.py、losses.py、ops.py)中。完整的模型(由一个或多个网络组成)应在以模型命名的文件中引用,例如 yolov3.py、dcgan.py。

(1) 构建神经网络

自定义网络一般继承自 nn.Module 类,并且必须实现 forward 方法,用于定义各层或各操作的前向传播。

对于具有单个输入单个输出的简单网络,请使用以下模式:

class ConvBlock(nn.Module):
    """Single-input, single-output block: Conv2d -> ReLU -> BatchNorm2d.

    The constructor arguments were previously left as ``...`` placeholders;
    they are exposed here as defaulted parameters so ``ConvBlock()`` still
    works and the block can actually be instantiated.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the convolution (and
            normalized by the BatchNorm layer).
        kernel_size: square convolution kernel size.
    """

    def __init__(self, in_channels=3, out_channels=64, kernel_size=3):
        super(ConvBlock, self).__init__()
        # Was `nn.Squential`, which does not exist and raised AttributeError.
        self.block = nn.Sequential(
            # padding=kernel_size // 2 keeps height/width unchanged for odd kernels.
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=kernel_size // 2),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        """Apply conv, ReLU and batch-norm in sequence to ``x``."""
        return self.block(x)

class SimpleNetwork(nn.Module):
    """Plain stack of ``num_of_layers`` shape-preserving Conv2d + ReLU layers.

    The layer body was previously the placeholder ``layers.append(..)``, which
    is a SyntaxError; each repetition now adds a 3x3 convolution (padding 1, so
    spatial size is preserved) followed by a ReLU.

    Args:
        num_of_layers: number of conv + ReLU repetitions (0 gives an identity).
        channels: channel count used for both the input and the output of every
            convolution.
    """

    def __init__(self, num_of_layers=15, channels=64):
        super(SimpleNetwork, self).__init__()
        layers = []
        for _ in range(num_of_layers):
            layers.append(nn.Conv2d(channels, channels, kernel_size=3, padding=1))
            layers.append(nn.ReLU(inplace=True))
        self.conv0 = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the whole layer stack."""
        return self.conv0(x)

我们建议将网络拆分为更小的可重用部分。网络由操作或其它网络模块组成。损失函数也是神经网络的模块,因此可以直接集成到网络中。

(2) 自定义数据集
class CustomDataset(Dataset):
    """Template map-style dataset yielding ``{'data': ..., 'target': ...}`` samples.

    ``train_data`` / ``train_target`` are placeholders to be filled with
    per-sample file paths; ``__getitem__`` opens both with ``Image.open``.

    Args:
        root_dir: directory holding the data (stored as ``self.root_dir``).
        transform: optional callable applied jointly to ``(data, target)``;
            it must return the transformed ``(data, target)`` pair.
        is_trainval: split flag passed by callers
            (``CustomDataset(..., is_trainval=True, ...)``). It is stored as
            ``self.is_trainval`` but this template does not use it yet.
    """

    def __init__(self, root_dir='./data', transform=None, is_trainval=True):
        self.root_dir = root_dir
        self.transform = transform
        self.is_trainval = is_trainval
        # TODO: fill these with the per-sample input / target paths.
        self.train_data = ...
        self.train_target = ...

    def __len__(self):
        """Return the number of samples (length of ``train_data``)."""
        return len(self.train_data)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return it as ``{'data': ..., 'target': ...}``."""
        # Samplers may hand over a tensor index; convert it to plain Python.
        if torch.is_tensor(idx):
            idx = idx.tolist()

        data = Image.open(self.train_data[idx])
        target = Image.open(self.train_target[idx])

        if self.transform is not None:
            data, target = self.transform(data, target)

        # Key was 'high_img', but consumers in this file read batchdata["target"]
        # (KeyError); the key must be 'target'.
        return {'data': data, 'target': target}
(3) 自定义损失

虽然 PyTorch 已经提供了很多标准的损失函数,但有时也需要创建自己的损失函数。为此,请创建单独的文件 losses.py,并继承 nn.Module 类来实现自定义的损失函数:

import torch
import torch.nn as nn

class CustomLoss(nn.Module):
    """Mean squared error between a prediction ``x`` and a target ``y``."""

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        """Return the mean of the element-wise squared difference ``x - y``."""
        residual = x - y
        return residual.square().mean()
(4) 训练模型的推荐代码结构
# import statements
import os

import numpy as np  # used for seeding below; it was referenced but never imported
import torch
import torch.nn as nn
from torch.utils import data
...  # project imports: args (argparse namespace), CustomNet, CustomDataset, save_checkpoint, ...

# set flags / seeds
# NOTE: cudnn.benchmark=True lets cuDNN pick the fastest kernels, which can make
# results non-reproducible even with the fixed seeds below.
torch.backends.cudnn.benchmark = True
np.random.seed(1)
torch.manual_seed(1)
torch.cuda.manual_seed(1)
...

# `device` is used throughout but was never defined in the visible code.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# dataset
transform_train = ...
transform_test = ...  # was misspelled `trainform_text`, leaving `transform_test` undefined

train_dataset = CustomDataset(args.train_dataset, is_trainval=True, transform=transform_train)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
                                           shuffle=True, num_workers=0, drop_last=False)
# NOTE(review): validation also passes is_trainval=True -- confirm this is the intended split.
valid_dataset = CustomDataset(args.valid_dataset, is_trainval=True, transform=transform_test)
# Sample order does not change evaluation results, so validation is not shuffled.
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.val_batch_size,
                                           shuffle=False, num_workers=0)

# model & loss
# Bound to `model` (previously `net`) because every function below uses `model`.
model = CustomNet().to(device)
criterion = ...
# lr & optimizer
# Bound to `optimizer` (previously `optim = optim.SGD(...)`, which referenced an
# unimported `optim` and a name nothing else used).
optimizer = torch.optim.SGD(model.parameters(), lr=args.init_lr, momentum=args.momentum,
                            weight_decay=args.weight_decay)
# Multiply the learning rate by 0.1 at epochs 50 and 70.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 70], gamma=0.1)

# Best validation score so far. It is initialised before resuming so that a value
# restored from a checkpoint is not reset to 0.0 before the epoch loop.
best_prec = 0.0

# load resume
if args.resume:
    if os.path.isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
        checkpoint = torch.load(args.resume, map_location=device)
        args.start_epoch = checkpoint['epoch']
        best_prec = checkpoint['best_prec']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Older checkpoints carry no scheduler state; without it the LR schedule
        # restarts counting from the resume point.
        if 'scheduler' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler'])
        print("=> loaded checkpoint '{}' (epoch {}) Prec: {:f}"
              .format(args.resume, checkpoint['epoch'], best_prec))
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))


def train(epoch):
    """Train `model` for one epoch over `train_loader`.

    Prints the loss of every batch and the epoch's average batch loss. Every
    `args.save_interval` epochs it writes a checkpoint to `args.checkpoint_dir`.

    Args:
        epoch: 0-based index of the current epoch.
    """
    # train() enables training-mode behavior (e.g. dropout, BatchNorm statistic
    # updates); test() switches the model to eval() before evaluating.
    model.train()

    total_loss = 0.0
    for batch_idx, batchdata in enumerate(train_loader):
        inputs, target = batchdata["data"], batchdata["target"]
        inputs, target = inputs.to(device), target.to(device)
        # Clear gradients accumulated by the previous step before backward().
        # The optimizer holds all of model.parameters(), so this is equivalent
        # to model.zero_grad().
        optimizer.zero_grad()

        predict = model(inputs)
        loss = criterion(predict, target)
        total_loss += loss.item()  # .item() so the running sum does not keep the graph alive

        loss.backward()
        optimizer.step()

        print('Train Epoch: {} [{}/{} ({:.1f}%)]\tLoss: {:.6f}'.format(
            epoch, batch_idx * len(inputs), len(train_loader.dataset),
            100. * batch_idx / len(train_loader), loss.item()))

    # The running loss was previously accumulated but never reported.
    print('Train Epoch: {}\tAverage loss: {:.6f}'.format(
        epoch, total_loss / max(len(train_loader), 1)))

    if (epoch + 1) % args.save_interval == 0:
        state = {'epoch': epoch + 1,
                 'state_dict': model.state_dict(),
                 'best_prec': best_prec,  # was hard-coded 0.0, which lost the best score on resume
                 'optimizer': optimizer.state_dict(),
                 'scheduler': scheduler.state_dict()}
        os.makedirs(args.checkpoint_dir, exist_ok=True)
        model_path = os.path.join(args.checkpoint_dir, 'model_' + str(epoch) + '.pth')
        torch.save(state, model_path)


def test():
    """Evaluate `model` on `valid_loader` and print the averaged results.

    Returns:
        The per-sample average of ``criterion(predict * 255, target * 255)``,
        the value the printout labels "PSNR".
    """
    model.eval()

    num_samples = len(valid_loader.dataset)
    total_loss = 0.0
    total_psnr = 0.0
    # No gradients are needed for evaluation. Previously `test_loss` summed
    # graph-carrying tensors, which kept every batch's autograd graph in memory.
    with torch.no_grad():
        for batchdata in valid_loader:
            inputs, target = batchdata["data"], batchdata["target"]
            inputs, target = inputs.to(device), target.to(device)
            predict = model(inputs)
            batch_size = len(inputs)
            # Weighting by batch size turns per-batch values into a per-sample
            # average (assumes criterion reduces with a mean -- TODO confirm).
            total_loss += criterion(predict, target).item() * batch_size
            # Previously overwritten on every batch, so only the last batch counted.
            # NOTE(review): this is criterion on a 0-255 scale, not PSNR itself; if
            # criterion is MSE, PSNR would be 10 * log10(255 ** 2 / mse).
            total_psnr += criterion(predict * 255, target * 255).item() * batch_size

    denom = max(num_samples, 1)  # avoid ZeroDivisionError on an empty validation set
    test_loss = total_loss / denom
    psnr = total_psnr / denom
    print('\nTest set: Average loss: {:.4f}, PSNR: ({:.1f})\n'.format(test_loss, psnr))
    return psnr


for epoch in range(args.start_epoch, args.epochs):
    train(epoch)
    scheduler.step()
    # Was print(print(...)), which also printed the inner call's return value, None.
    print('lr: {}'.format(optimizer.param_groups[0]['lr']))

    current_prec = test()
    # Higher is better here (e.g. PSNR). For a lower-is-better metric such as a
    # loss, flip the comparison to `<` and use min() below.
    is_best = current_prec > best_prec
    best_prec = max(current_prec, best_prec)  # was max(best_prec, best_prec), which never updated

    # save_checkpoint is not defined in the visible code; presumably it comes from the project imports.
    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'best_prec': best_prec,
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
    }, is_best, args.checkpoint_dir)

本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!