深度学习的项目结构和开发规范
文件组织结构
推荐采用下列文件组织结构。
1 | ├── checkpoints/ |
其中:
- checkpoints/: 用于保存训练好的模型,可使程序在异常退出后仍能重新载入模型,恢复训练
- data/:数据相关操作,包括数据预处理、dataset实现等
- models/:模型定义,可以有多个模型,例如上面的AlexNet和ResNet34,一个模型对应一个文件
- models/lib/:构成模型的相关部件
- utils/:可能用到的工具函数,在本次实验中主要是封装了可视化工具
- config.py:配置文件,所有可配置的变量都集中在此,并提供默认值
- main.py:主文件,训练和测试程序的入口,可通过不同的命令来指定不同的操作和参数
- requirements.txt:程序依赖的第三方库
- README.md:提供程序的必要说明
参数传递
深度学习的模型有很多参数,常规的变量名方式的赋值很占行数,
这都是其次的,主要是没有高亮,在长篇的coding中,修改值会很不方便和吃力
利用argparse这个包可以完美解决这个问题
建议在config.py中定义此文件, 这样可以直接在运行时设置参数,也方便后面的调用
例如:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38import argparse
def parse_opt():
parser = argparse.ArgumentParser()
####### Original hyper-parameters #######
# Data input settings
parser.add_argument('--input_json', type=str, default='data/cocotalk.json',
help='path to the json file containing additional info and vocab')
# Model settings
parser.add_argument('--rnn_size', type=int, default=512,
help='size of the rnn in number of hidden nodes in each layer')
# feature manipulation
parser.add_argument('--norm_att_feat', type=int, default=0,
help='If normalize attention features')
# Optimization: General
parser.add_argument('--max_epochs', type=int, default=-1,
help='number of epochs')
# Sample related
parser.add_argument('--max_length', type=int, default=20,
help='Maximum length during sampling')
#Optimization: for the Language Model
parser.add_argument('--learning_rate', type=float, default=4e-4,
help='learning rate')
# Transformer
parser.add_argument('--noamopt', action='store_true',
help='')
# Evaluation/Checkpointing
parser.add_argument('--save_checkpoint_every', type=int, default=2500,
help='how often to save a model checkpoint (in iterations)?')
args = parser.parse_args()
# Check if args are valid
assert args.rnn_size > 0, "rnn_size should be greater than 0"
assert args.num_layers > 0, "num_layers should be greater than 0"
return args
init构造函数
将所有的之前声明的需要传递的参数整合到当前的类中,供内部的成员函数分享使用。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18class AttModel(CaptionModel):
def __init__(self, opt):
super(AttModel, self).__init__()
self.vocab_size = opt.vocab_size
self.input_encoding_size = opt.input_encoding_size
self.rnn_size = opt.rnn_size
self.num_layers = opt.num_layers
self.drop_prob_lm = opt.drop_prob_lm
self.seq_length = opt.max_length or opt.seq_length
self.fc_feat_size = opt.fc_feat_size
self.att_feat_size = opt.att_feat_size
self.att_hid_size = opt.att_hid_size
self.use_bn = opt.use_bn
self.ss_prob = opt.sampling_prob
self.gpn = True if opt.use_gpn == 1 else False
self.embed_dim = opt.embed_dim
self.GCN_dim = opt.gcn_dim