Forward_Backward_Update耗时测试

Forward_Backward_Update耗时测试

本文编写代码,测试模型训练过程中正向传播、反向传播、参数更新三个阶段各自的耗时及其比例。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch 
import time
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Small LeNet-style CNN used as the workload for the timing experiment.

    Two 5x5 convolutions (1 -> 6 -> 16 channels), each followed by ReLU and
    2x2 max pooling, feed three fully connected layers (400 -> 120 -> 84 -> 10).
    ``fc1`` expects 16 * 5 * 5 features per sample, which is what the
    convolution/pooling stack produces for a 1 x 32 x 32 input.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return the 10 raw class scores (no softmax) for each sample in ``x``."""
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Collapse every non-batch dimension so the linear layers get 2-D input.
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        """Return the product of all dimensions of ``x`` except the first."""
        size = x.size()[1:]  # drop the leading (batch) dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

if __name__ == '__main__':
    # Benchmark: print the time spent in the forward pass, the backward pass
    # and the optimizer update for 10 training steps on one fixed random batch.
    model = Net()
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # A single synthetic batch: 256 one-channel 32x32 inputs with integer
    # class labels drawn from [0, 10).
    inputs = torch.randn(256, 1, 32, 32)
    targets = torch.empty(256, dtype=torch.long).random_(10)

    for _ in range(10):
        # perf_counter() is the clock meant for measuring short intervals;
        # time.time() can be adjusted by the system and may be coarse.
        start_time = time.perf_counter()
        outputs = model(inputs)
        # The loss computation is counted as part of the forward phase.
        loss = criterion(outputs, targets)

        print('Forward: {}s'.format(time.perf_counter() - start_time))

        # Clearing the old gradients is deliberately left outside the timers.
        model.zero_grad()
        start_time = time.perf_counter()
        loss.backward()
        print('Backward: {}s'.format(time.perf_counter() - start_time))

        start_time = time.perf_counter()
        optimizer.step()
        print('Update: {}s'.format(time.perf_counter() - start_time))
        print('========================')
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
Forward: 0.035942792892456055s
Backward: 0.035076141357421875s
Update: 0.0007233619689941406s
========================
Forward: 0.025434494018554688s
Backward: 0.03362894058227539s
Update: 0.00023508071899414062s
========================
Forward: 0.032956838607788086s
Backward: 0.0335536003112793s
Update: 0.0002675056457519531s
========================
Forward: 0.02550196647644043s
Backward: 0.04209303855895996s
Update: 0.00024390220642089844s
========================
Forward: 0.02639007568359375s
Backward: 0.033740997314453125s
Update: 0.001081705093383789s
========================
Forward: 0.026480913162231445s
Backward: 0.033477067947387695s
Update: 0.0007789134979248047s
========================
Forward: 0.02574777603149414s
Backward: 0.033417463302612305s
Update: 0.000713348388671875s
========================
Forward: 0.027414798736572266s
Backward: 0.03293299674987793s
Update: 0.00022792816162109375s
========================
Forward: 0.02549004554748535s
Backward: 0.03346991539001465s
Update: 0.00023126602172851562s
========================
Forward: 0.025561094284057617s
Backward: 0.033406972885131836s
Update: 0.00024080276489257812s
========================

根据上面的实验结果,三个阶段的耗时关系总体上为 $\text{Backward} > \text{Forward} \gg \text{Update}$,即反向传播耗时略高于正向传播,而参数更新的耗时远小于前两者。

0%