pytorch 从指定epoch恢复训练
2024-04-09 20:30:34  阅读数 735

1、保存模型

保存整个模型

torch.save(net, path)

保存权重

state_dict = net.state_dict()

torch.save(state_dict , path)

2、模型训练过程保存

checkpoint = {

        "net": model.state_dict(),

        'optimizer':optimizer.state_dict(),

        "epoch": epoch

    }

3、指定epoch恢复

path_checkpoint = "./models/checkpoint/ckpt_best_1.pth" # 断点路径

checkpoint = torch.load(path_checkpoint)  # 加载断点

model.load_state_dict(checkpoint['net'])  # 加载模型可学习参数

optimizer.load_state_dict(checkpoint['optimizer'])  # 加载优化器参数

start_epoch = checkpoint['epoch']  # 设置开始的epoch