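I have recently been learning TensorFlow2 and using tf2 to build deep learning recommender algorithms. I ran into several problems with saving and loading models, felt rather helpless, and could not find a direct fix, so I am recording them here in the hope that readers can offer some hints.
1. Saving and loading weights the general way (works fine)
This approach uses tf2's save_weights and load_weights. Loading the weights requires building the network first. A pseudocode example follows:
(1) A custom neural network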
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense


class Deep(Model):
    """
    DNN network
    """
    def __init__(self, hiddens):
        """
        :param hiddens: number of neurons in each hidden layer
        """
        super(Deep, self).__init__()
        self.hiddens = hiddens
        self.deep_model = tf.keras.models.Sequential()
        for hidden in hiddens:
            self.deep_model.add(Dense(hidden, activation='relu', kernel_regularizer=tf.keras.regularizers.l2()))

    def call(self, inputs, training=None, mask=None):
        return self.deep_model(inputs)
(2) Train the model and save it
# train_data, dev_data and class_weight are prepared elsewhere (pseudocode)
model = Deep([64, 64])
model.compile(loss='binary_crossentropy', metrics=['accuracy', 'AUC'])
checkpoint_save_path = 'model/deep_cross.ckpt'
# save_weights_only=True: the callback writes the weights, not the whole model
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path, save_weights_only=True)
model.fit(train_data, epochs=1, class_weight=class_weight, validation_data=dev_data, callbacks=[cp_callback])
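(3) Model files
Saving weights to a .ckpt path in TF2 normally leaves a checkpoint file plus deep_cross.ckpt.index and deep_cross.ckpt.data-00000-of-00001 under model/.
(4) Load the model and predict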
# Rebuild the same architecture first, then load the saved weights into it
model = Deep([64, 64])
model.compile(loss='binary_crossentropy', metrics=['accuracy', 'AUC'])
checkpoint_save_path = 'model/deep_cross.ckpt'
model.load_weights(checkpoint_save_path)
result = model.predict(dev_data)
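The save_weights / load_weights round trip above can be checked end to end with a small sketch. Everything in it is an assumption for illustration: TinyDeep is a hypothetical stand-in for Deep, the input is random data, and the file name demo.weights.h5 is used only because newer Keras releases expect that suffix for save_weights (the .ckpt path above targets earlier TF2 releases).

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class TinyDeep(tf.keras.Model):
    """Hypothetical stand-in for the Deep network (illustration only)."""

    def __init__(self, hiddens):
        super().__init__()
        self.hidden_layers = [
            tf.keras.layers.Dense(h, activation="relu") for h in hiddens
        ]
        self.out = tf.keras.layers.Dense(1, activation="sigmoid")

    def call(self, inputs, training=None, mask=None):
        x = inputs
        for layer in self.hidden_layers:
            x = layer(x)
        return self.out(x)


# Random sample batch (assumed shape: 8 rows, 4 features).
x = np.random.rand(8, 4).astype("float32")

trained = TinyDeep([16, 16])
before = trained(x).numpy()  # the first call creates the variables

weights_path = os.path.join(tempfile.mkdtemp(), "demo.weights.h5")
trained.save_weights(weights_path)

# A fresh instance must be rebuilt from code and built before loading.
restored = TinyDeep([16, 16])
restored(x)
restored.load_weights(weights_path)
after = restored(x).numpy()

weights_match = bool(np.allclose(before, after))
print("predictions identical after reload:", weights_match)
```

The point of the sketch is that the class definition has to be available on the loading side; only the numbers come from the file.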
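2. Using h5 models (one pitfall after another, given up for now)
The reason for trying h5 is that I do not want to keep the network-building code around; I only want to load the model directly and predict.
bug 1: Unable to create link (name already exists)
This happens when a custom network adds its w and b parameters in a loop without giving the variables names, so they all end up with the same name. The fix, together with the problematic code and comments, is below: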
class Cross(Model):
    """
    Cross network
    """
    def __init__(self, embed_dim, layer_num):
        """
        :param embed_dim: length of the input vector
        :param layer_num: number of cross layers
        """
        super(Cross, self).__init__()
        self.embed_dim = embed_dim
        self.layer_num = layer_num
        self.ws = []
        self.bs = []
        # tf.Variable needs an explicit name argument, and the names must differ;
        # without it, saving to h5 raises the error above
        for i in range(layer_num):
            # name belongs inside tf.Variable(...), not in list.append(...)
            self.ws.append(tf.Variable(tf.random.truncated_normal(shape=(embed_dim, 1), stddev=0.01), name='w' + str(i)))
            self.bs.append(tf.Variable(tf.zeros(shape=(embed_dim, 1)), name='b' + str(i)))

    def call(self, inputs, training=None, mask=None):
        # Assumption: inputs have shape (batch, embed_dim, 1), so the broadcasts below line up
        x0 = inputs
        xl = x0
        for w, b in zip(self.ws, self.bs):
            xl_T = tf.reshape(xl, [-1, 1, self.embed_dim])  # (batch, 1, embed_dim)
            xlTw = tf.tensordot(xl_T, w, axes=1)            # (batch, 1, 1)
            xl = x0 * xlTw + b + xl
        return xl
# Build and train the model
model = Cross(256, 3)
model.compile(loss='binary_crossentropy', metrics=['accuracy', 'AUC'])
# Use the .h5 suffix, and drop save_weights_only from the callback
checkpoint_save_path = 'model/deep_cross.h5'
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path)
model.fit(train_data, epochs=1, class_weight=class_weight, validation_data=dev_data, callbacks=[cp_callback])
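The name collision behind bug 1 can be seen without any model at all. The sketch below is only an illustration (layer_num and embed_dim are arbitrary small values): in TF2 eager mode, variables created without a name argument all report the same default name, while explicitly named ones stay distinct.

```python
import tensorflow as tf

layer_num = 3   # arbitrary value for the demo
embed_dim = 4   # arbitrary value for the demo

# Created in a loop without a name: every variable gets the same default name.
unnamed = [tf.Variable(tf.zeros((embed_dim, 1))) for _ in range(layer_num)]

# Created with an explicit, distinct name per layer.
named = [
    tf.Variable(tf.zeros((embed_dim, 1)), name="w" + str(i))
    for i in range(layer_num)
]

unnamed_names = [v.name for v in unnamed]
named_names = [v.name for v in named]

print("unnamed:", unnamed_names)
print("named:  ", named_names)
```

HDF5 stores each weight under its name, so identical names compete for the same slot in the file, which matches the "name already exists" message.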
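bug 2: No model found in config file
Continuing from the code above, loading the h5 file with tf.keras.models.load_model('model/deep_cross.h5') fails immediately. load_model raises this error when the .h5 file carries no saved model configuration, which means the file written during training holds weights only.
(1) Tried overriding get_config; it did not help.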
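One way to diagnose bug 2 is to look inside the .h5 file. The sketch below assumes that h5py is installed (it normally comes with TensorFlow) and relies on the architecture being stored in the file's model_config attribute; has_model_config is a hypothetical helper, and the tiny Sequential model and file names are made up for the demo.

```python
import os
import tempfile

import h5py
import tensorflow as tf


def has_model_config(h5_path):
    """Hypothetical helper: True if the HDF5 file stores an architecture config."""
    with h5py.File(h5_path, "r") as f:
        return "model_config" in f.attrs


# A Sequential model is used here because it can be saved to h5 in full.
net = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1),
])

tmp_dir = tempfile.mkdtemp()
full_path = os.path.join(tmp_dir, "full_model.h5")
weights_path = os.path.join(tmp_dir, "only.weights.h5")

net.save(full_path)             # architecture + weights
net.save_weights(weights_path)  # weights only

full_has_config = has_model_config(full_path)
weights_has_config = has_model_config(weights_path)

print("full model file has config:  ", full_has_config)
print("weights-only file has config:", weights_has_config)
```

Running the same helper on model/deep_cross.h5 would show whether that file has anything for load_model to rebuild the network from.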
bug 3: NotImplementedError: Saving the model to HDF5
Full error message
NotImplementedError: Saving the model to HDF5 format requires the model to be a Functional model or a Sequential model. It does not work for subclassed models, because such models are defined via the body of a Python method, which isn't safely serializable. Consider saving to the Tensorflow SavedModel format (by setting save_format="tf") or using `save_weights`.
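This looks like the same root cause as bug 2: a custom (subclassed) network cannot be saved as h5. Here the model is not saved during training; it is saved to h5 once training finishes, and the save call fails immediately.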
model = Cross(256, 3)
model.compile(loss='binary_crossentropy', metrics=['accuracy', 'AUC'])
# Use the .h5 suffix; no checkpoint callback here, the model is saved once after training
checkpoint_save_path = 'model/deep_cross.h5'
model.fit(train_data, epochs=1, class_weight=class_weight, validation_data=dev_data)
model.save(checkpoint_save_path)  # raises the NotImplementedError shown above
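The error text itself says that HDF5 saving needs a Functional or Sequential model. As a rough sketch of that route (not the code used in this post), a plain DNN can be written with the Functional API, saved to h5, and loaded back without any class definition. The layer sizes, input width and file name below are assumptions for the demo, and a custom layer such as Cross would still need extra work.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Functional-API version of a small DNN (sizes are arbitrary for the demo).
inputs = tf.keras.Input(shape=(4,))
x = inputs
for units in [16, 16]:
    x = tf.keras.layers.Dense(units, activation="relu")(x)
outputs = tf.keras.layers.Dense(1, activation="sigmoid")(x)
functional_model = tf.keras.Model(inputs, outputs)

h5_path = os.path.join(tempfile.mkdtemp(), "deep_functional.h5")
functional_model.save(h5_path)  # full model: architecture + weights

# Loading needs only the file, not the code that built the network.
loaded = tf.keras.models.load_model(h5_path, compile=False)

sample = np.random.rand(8, 4).astype("float32")
same_output = bool(
    np.allclose(functional_model(sample).numpy(), loaded(sample).numpy())
)
print("loaded model reproduces the original outputs:", same_output)
```

This gives the "load and predict without the building code" workflow that motivated h5, at the cost of giving up the subclassed style.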
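In summary
For now it is better to stick with save_weights and load_weights. After all, this is a custom network, and the code for it is on hand anyway.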