部分代码参考:https://blog.csdn.net/lai_cheng/article/details/90643100
剪枝:
剪枝就是利用某一个准则对某一组或某一个权值置0从而达到将网络神经元置0以达到稀疏化网络连接从而加快整个推理过程及缩小模型大小的迭代过程,这个准则有暴力穷尽组合排忧、使用对角 Hessian 逼近计算每个权值的重要性、基于一阶泰勒展开的模型代价函数来对权值排序、基于L1绝对值的权值参数大小进行排序、基于在小验证集上的影响进行分值分配排序等方法,而某一组或某一个网络权值则可以是整个卷积核、全连接层、卷积核或全连接层上的某个权重参数,剪枝的目的是将冗余的神经元参数置0减小模型大小(需要特殊的模型存储方式)减少计算参数(需要某种特殊的硬件计算方式)稀疏化网络连接加快推理速度。
模型剪枝方法:
model_pruning:模型训练时剪枝,只需选定需要剪枝的层,对于选中做剪枝的层增加一个二进制掩模(mask)变量,形状和该层的权值张量形状完全相同。该掩模决定了哪些权值参与前向计算。掩模更新算法则需要为 TensorFlow 训练计算图注入特殊运算符,对当前层权值按绝对值大小排序,对幅度小于一定门限的权值将其对应掩模值设为 0。反向传播梯度也经过掩模,被屏蔽的权值(mask 为 0)在反向传播步骤中无法获得更新量。在保存模型时则可以通过去掉剪枝Ops的方式直接稀疏化权重,这样就起到了稀疏连接的作用。
官方提供model_pruning例子:
tf.app.flags.DEFINE_string(
'pruning_hparams', '',
"""Comma separated list of pruning-related hyperparameters""")
with tf.graph.as_default():
# Create global step variable
global_step = tf.train.get_or_create_global_step()
# Parse pruning hyperparameters
pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)
# Create a pruning object using the pruning specification
p = pruning.Pruning(pruning_hparams, global_step=global_step)
# Add conditional mask update op. Executing this op will update all
# the masks in the graph if the current global step is in the range
# [begin_pruning_step, end_pruning_step] as specified by the pruning spec
mask_update_op = p.conditional_mask_update_op()
# Add summaries to keep track of the sparsity in different layers during training
p.add_pruning_summaries()
with tf.train.MonitoredTrainingSession(...) as mon_sess:
# Run the usual training op in the tf session
mon_sess.run(train_op)
# Update the masks by running the mask_update_op
mon_sess.run(mask_update_op)
一定要保证传给pruning的global_step是随着训练迭代保持增长的,否则不会产生剪枝效果!
全连接层剪枝:
from tensorflow.contrib.model_pruning.python.layers import layers
fc_layer1 = layers.masked_fully_connected(ft, 200)
fc_layer2 = layers.masked_fully_connected(fc_layer1, 100)
prediction = layers.masked_fully_connected(fc_layer2, 10)
卷积层剪枝: