深度学习的项目结构和开发规范

深度学习的项目结构和开发规范

文件组织结构

推荐采用下列文件组织结构。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
├── checkpoints/
├── data/
│ ├── __init__.py
│ ├── dataset.py
│ └── get_data.sh
├── models/
│ ├──lib/
│ │ ├──__init__.py
│ │ └──graph_conv_unit.py
│ ├── __init__.py
│ ├── AlexNet.py
│ ├── BasicModule.py
│ └── ResNet34.py
└── utils/
│ ├── __init__.py
│ └── visualize.py
├── config.py
├── main.py
├── requirements.txt
├── README.md

其中:

  • checkpoints/: 用于保存训练好的模型,可使程序在异常退出后仍能重新载入模型,恢复训练
  • data/:数据相关操作,包括数据预处理、dataset实现等
  • models/:模型定义,可以有多个模型,例如上面的AlexNet和ResNet34,一个模型对应一个文件
  • models/lib/:构成模型的相关部件
  • utils/:可能用到的工具函数,在本次实验中主要是封装了可视化工具
  • config.py:配置文件,所有可配置的变量都集中在此,并提供默认值
  • main.py:主文件,训练和测试程序的入口,可通过不同的命令来指定不同的操作和参数
  • requirements.txt:程序依赖的第三方库
  • README.md:提供程序的必要说明

参数传递

深度学习的模型有很多参数,常规的变量名方式的赋值很占行数,
这都是其次的,主要是没有高亮,在长篇的coding中,修改值会很不方便和吃力
利用argparse这个包可以完美解决这个问题

建议在config.py中定义此文件, 这样可以直接在运行时设置参数,也方便后面的调用
例如:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import argparse

def parse_opt():
parser = argparse.ArgumentParser()
####### Original hyper-parameters #######
# Data input settings
parser.add_argument('--input_json', type=str, default='data/cocotalk.json',
help='path to the json file containing additional info and vocab')
# Model settings
parser.add_argument('--rnn_size', type=int, default=512,
help='size of the rnn in number of hidden nodes in each layer')
# feature manipulation
parser.add_argument('--norm_att_feat', type=int, default=0,
help='If normalize attention features')
# Optimization: General
parser.add_argument('--max_epochs', type=int, default=-1,
help='number of epochs')
# Sample related
parser.add_argument('--max_length', type=int, default=20,
help='Maximum length during sampling')

#Optimization: for the Language Model
parser.add_argument('--learning_rate', type=float, default=4e-4,
help='learning rate')
# Transformer
parser.add_argument('--noamopt', action='store_true',
help='')
# Evaluation/Checkpointing
parser.add_argument('--save_checkpoint_every', type=int, default=2500,
help='how often to save a model checkpoint (in iterations)?')

args = parser.parse_args()

# Check if args are valid
assert args.rnn_size > 0, "rnn_size should be greater than 0"
assert args.num_layers > 0, "num_layers should be greater than 0"

return args

init构造函数

将所有的之前声明的需要传递的参数整合到当前的类中,供内部的成员函数分享使用。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class AttModel(CaptionModel):
def __init__(self, opt):
super(AttModel, self).__init__()
self.vocab_size = opt.vocab_size
self.input_encoding_size = opt.input_encoding_size
self.rnn_size = opt.rnn_size
self.num_layers = opt.num_layers
self.drop_prob_lm = opt.drop_prob_lm
self.seq_length = opt.max_length or opt.seq_length
self.fc_feat_size = opt.fc_feat_size
self.att_feat_size = opt.att_feat_size
self.att_hid_size = opt.att_hid_size
self.use_bn = opt.use_bn
self.ss_prob = opt.sampling_prob

self.gpn = True if opt.use_gpn == 1 else False
self.embed_dim = opt.embed_dim
self.GCN_dim = opt.gcn_dim

0%