ECE 579 Homework 5 - Chaoji Zuo

Rutgers University - ECE 579
Introduction to Deep Learning
Homework Assignment 5
Chaoji Zuo
chaoji.zuo@rutgers.edu
CZ296 / 190003416
Problem (Prune a LeNet-5). In this problem, you are asked to train and test a
LeNet-5 neural network on the MNIST dataset.
I tried different compression ratios, and the pruned models all performed well; the difference in accuracy between them is very small.
Figure 1: Compression Ratio 0.5
Figure 2: Compression Ratio 0.25
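To make the ratio concrete, here is a small sketch in plain Python (no PyTorch) of magnitude-based threshold selection on eight made-up weights. It assumes that the compression ratio in the figures is the `ratio` argument of the pruning class in this report, i.e. the fraction of weights to zero out: the threshold is the element at position int(ratio * N) of the sorted absolute values, and every weight with a smaller magnitude is set to zero. The function name `magnitude_prune` and the sample values are illustrative assumptions, not part of the submitted code.

```python
def magnitude_prune(weights, ratio):
    """Zero out roughly `ratio` of the weights, smallest magnitudes first."""
    magnitudes = sorted(abs(w) for w in weights)
    # Clamp the index so ratio = 1.0 does not run past the end of the list
    # (an added safeguard; with the clamp the largest weight is always kept).
    index = min(int(ratio * len(magnitudes)), len(magnitudes) - 1)
    threshold = magnitudes[index]
    # Weights whose magnitude equals the threshold are kept.
    mask = [1.0 if abs(w) >= threshold else 0.0 for w in weights]
    pruned = [w * m for w, m in zip(weights, mask)]
    return pruned, mask, threshold


# Hypothetical weights, chosen only for illustration.
weights = [0.9, -0.05, 0.3, -0.7, 0.01, 0.2, -0.4, 0.6]
for ratio in (0.25, 0.5):
    pruned, mask, threshold = magnitude_prune(weights, ratio)
    sparsity = mask.count(0.0) / len(mask)
    print(f"ratio={ratio}: threshold={threshold}, sparsity={sparsity:.3f}")
```

With these values, a ratio of 0.25 zeroes two of the eight weights and a ratio of 0.5 zeroes four, so the measured sparsity matches the requested ratio whenever there are no ties at the threshold.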
Core code (the complete code files are included in the attachment):
import torch
import torch.nn as nn


class PruningWeight(object):
    def __init__(self, ratio=0.0):
        self.MaskList = []    # one binary mask per pruned layer
        self.threshold = []   # one magnitude threshold per pruned layer
        self.ratio = ratio    # fraction of weights to zero in each layer

    def Init(self, model):
        # Prune every Linear and Conv2d layer in place and record its mask.
        for m in model.modules():
            if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
                with torch.no_grad():
                    m.weight.copy_(self.SetUpPruning(m.weight, self.ratio))

    def SetUpPruning(self, weight, ratio):
        threshold = self.FindMidValue(weight, ratio)
        sparse_weight, Mask = self.InitMask(weight, threshold)
        self.threshold.append(threshold)
        self.MaskList.append(Mask)
        return sparse_weight

    def FindMidValue(self, weight, ratio):
        # The threshold is the magnitude at position ratio * N in sorted order.
        flatten_weight = torch.flatten(torch.abs(weight))
        print(flatten_weight.shape)
        sorted_weight, _ = torch.sort(flatten_weight)
        # Clamp so that ratio = 1.0 does not index past the last element.
        index = min(int(ratio * flatten_weight.size()[0]), flatten_weight.size()[0] - 1)
        threshold = sorted_weight[index]
        return threshold

    def InitMask(self, w, threshold):
        # mask is 1.0 where |w| >= threshold and 0.0 elsewhere.
        mask = torch.abs(w).ge(threshold).type(dtype=torch.float32)
        # Modifies w in place; Init calls this under torch.no_grad().
        w[torch.abs(w) < threshold] = 0.0
        return w, mask

    def RecoverSparse(self, model):
        # Re-apply the stored masks so that pruned weights return to zero.
        idx = 0
        for m in model.modules():
            if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
                with torch.no_grad():
                    m.weight.copy_(m.weight * self.MaskList[idx])
                idx += 1

    def TestSparse(self, model):
        # Report the fraction of exactly-zero weights over all pruned layers.
        zero_cnt = 0
        all_cnt = 0
        for m in model.modules():
            if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
                w = m.weight
                zero_cnt += torch.sum((w == 0).int()).item()
                all_cnt += w.nelement()
        print('the sparsity of model is %.3f' % (1.0 * zero_cnt / all_cnt))
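The listing above only provides the building blocks; the report does not show how they are called during training. One plausible workflow, sketched here under assumptions, is to prune once and then fine-tune while re-applying the stored masks after every optimizer step, so that pruned weights stay at zero. The small stand-in model, the random data, and the helper names (`prunable_layers`, `build_masks`, `apply_masks`, `sparsity`) are hypothetical; they loosely mirror `Init`, `RecoverSparse`, and `TestSparse` and use only standard PyTorch calls.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for LeNet-5; the real model is in the attached files.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 32), nn.ReLU(), nn.Linear(32, 10))


def prunable_layers(model):
    """Return the layers whose weights are pruned (Linear and Conv2d)."""
    return [m for m in model.modules() if isinstance(m, (nn.Linear, nn.Conv2d))]


def build_masks(model, ratio):
    """Zero the smallest-magnitude weights of each layer and return the masks."""
    masks = []
    with torch.no_grad():
        for layer in prunable_layers(model):
            sorted_magnitudes, _ = torch.sort(layer.weight.abs().flatten())
            index = min(int(ratio * sorted_magnitudes.numel()), sorted_magnitudes.numel() - 1)
            mask = (layer.weight.abs() >= sorted_magnitudes[index]).float()
            layer.weight.mul_(mask)
            masks.append(mask)
    return masks


def apply_masks(model, masks):
    """Force the pruned positions back to zero."""
    with torch.no_grad():
        for layer, mask in zip(prunable_layers(model), masks):
            layer.weight.mul_(mask)


def sparsity(model):
    """Fraction of exactly-zero weights over all prunable layers."""
    zeros = sum((layer.weight == 0).sum().item() for layer in prunable_layers(model))
    total = sum(layer.weight.nelement() for layer in prunable_layers(model))
    return zeros / total


masks = build_masks(model, ratio=0.5)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

# Random tensors in place of MNIST batches, purely to exercise the loop.
inputs = torch.randn(16, 1, 28, 28)
targets = torch.randint(0, 10, (16,))

for _ in range(3):
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()           # the update can make pruned weights non-zero again
    apply_masks(model, masks)  # so the masks are re-applied after every step

print(f"sparsity after fine-tuning: {sparsity(model):.3f}")
```

Without the `apply_masks` call, the gradient update would move the pruned weights away from zero and the sparsity set by the initial pruning would be lost during fine-tuning.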