A Few Pitfalls When Saving and Loading TensorFlow 2 Models

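I have recently been learning TensorFlow 2 and using it to build deep-learning recommender-system algorithms. Saving and loading models gave me several problems that left me stuck, and I could not find a direct solution, so I am recording them here in the hope that fellow bloggers can offer some hints.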

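1. Saving and loading ordinary model weights (works fine)

This uses TF2's save_weights and load_weights. With this approach the network has to be built in code before the weights can be loaded. Example code:
(1) Custom neural network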

import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense


class Deep(Model):
    """
    DNN network
    """
    def __init__(self, hiddens):
        """
        :param hiddens: number of neurons in each hidden layer
        """
        super(Deep, self).__init__()
        self.hiddens = hiddens
        self.deep_model = tf.keras.models.Sequential()
        for hidden in hiddens:
            self.deep_model.add(Dense(hidden, activation='relu', kernel_regularizer=tf.keras.regularizers.l2()))

    def call(self, inputs, training=None, mask=None):
        return self.deep_model(inputs)

(2) Training and saving the model

# train_data, dev_data and class_weight are the author's own data and are not defined in this post
model = Deep([64, 64])
model.compile(loss='binary_crossentropy', metrics=['accuracy', 'AUC'])

checkpoint_save_path = 'model/deep_cross.ckpt'
# save_weights_only=True: write only the weights, not the architecture
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path, save_weights_only=True)
model.fit(train_data, epochs=1, class_weight=class_weight, validation_data=dev_data, callbacks=[cp_callback])
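To try the save step without the post's data, here is a minimal sketch on random synthetic data. Assumptions: TensorFlow 2.x; a hypothetical 1-unit sigmoid output layer (`head`) on top of Deep, because the post's Deep outputs 64 ReLU units that the AUC metric cannot score; and a `.weights.h5` file name, which newer Keras versions require for weights-only checkpoints.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense


class Deep(Model):
    """DNN from the post plus a hypothetical sigmoid output so it can be trained."""

    def __init__(self, hiddens):
        super().__init__()
        self.deep_model = tf.keras.models.Sequential()
        for hidden in hiddens:
            self.deep_model.add(Dense(hidden, activation="relu",
                                      kernel_regularizer=tf.keras.regularizers.l2()))
        self.head = Dense(1, activation="sigmoid")  # hypothetical output layer

    def call(self, inputs, training=None, mask=None):
        return self.head(self.deep_model(inputs))


# Hypothetical synthetic data: 256 samples, 10 features, binary labels.
rng = np.random.default_rng(0)
x = rng.random((256, 10), dtype=np.float32)
y = rng.integers(0, 2, size=(256, 1)).astype("float32")
train_data = tf.data.Dataset.from_tensor_slices((x, y)).batch(32)

model = Deep([64, 64])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy", "AUC"])

checkpoint_save_path = os.path.join(tempfile.mkdtemp(), "deep_cross.weights.h5")
# Weights-only checkpoint, written at the end of each epoch.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path, save_weights_only=True)
model.fit(train_data, epochs=1, callbacks=[cp_callback], verbose=0)

print(os.path.exists(checkpoint_save_path))
```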

(3) Model files
[Screenshot: the model files generated under model/]
(4) Loading the model for prediction

# Rebuild the same network in code, then restore the saved weights into it
model = Deep([64, 64])
model.compile(loss='binary_crossentropy', metrics=['accuracy', 'AUC'])

checkpoint_save_path = 'model/deep_cross.ckpt'
model.load_weights(checkpoint_save_path)

result = model.predict(dev_data)
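A self-contained sketch of the round trip in (2) and (4), assuming TensorFlow 2.x: save the weights of one Deep instance, rebuild the network in code, load the weights, and check that both instances predict the same values. The random input batch and the `.weights.h5` file name are assumptions for the sketch; the new instance is called once before load_weights because HDF5 weight files can only be loaded into a model whose variables already exist.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense


class Deep(Model):
    """Same DNN as in the post."""

    def __init__(self, hiddens):
        super().__init__()
        self.hiddens = hiddens
        self.deep_model = tf.keras.models.Sequential()
        for hidden in hiddens:
            self.deep_model.add(Dense(hidden, activation="relu",
                                      kernel_regularizer=tf.keras.regularizers.l2()))

    def call(self, inputs, training=None, mask=None):
        return self.deep_model(inputs)


x = np.random.default_rng(1).random((8, 10), dtype=np.float32)  # hypothetical batch

# "Training" side: build the model and save only its weights.
model = Deep([64, 64])
expected = model.predict(x, verbose=0)
path = os.path.join(tempfile.mkdtemp(), "deep_cross.weights.h5")
model.save_weights(path)

# "Prediction" side: rebuild the network in code, create its variables, load the weights.
restored = Deep([64, 64])
restored(x[:1])  # one call creates the variables
restored.load_weights(path)
result = restored.predict(x, verbose=0)

print(np.allclose(expected, result))
```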

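2. Using h5 models (one pitfall after another; on hold for now)

The point of using h5 was to skip the network-construction code entirely and just load the model and predict.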

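Bug 1: Unable to create link (name already exists)

This error came from my custom network, which creates the w and b parameters in a loop without giving the variables names. The problem code, the fix and comments are below: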

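import tensorflow as tf
from tensorflow.keras import Model


class Cross(Model):
    """
    Cross network
    """
    def __init__(self, embed_dim, layer_num):
        """
        :param embed_dim: length of the input vector
        :param layer_num: number of cross layers
        """
        super(Cross, self).__init__()
        self.embed_dim = embed_dim
        self.layer_num = layer_num
        self.ws = []
        self.bs = []
        # Each tf.Variable needs its own name argument and the names must differ; without them you get the bug above
        for i in range(layer_num):
            # name= belongs inside tf.Variable(...); list.append() accepts no keyword arguments
            self.ws.append(tf.Variable(tf.random.truncated_normal(shape=(embed_dim, 1), stddev=0.01), name='w' + str(i)))
            self.bs.append(tf.Variable(tf.zeros(shape=(embed_dim, 1)), name='b' + str(i)))

    def call(self, inputs, training=None, mask=None):
        # inputs: (batch, embed_dim) -> x0: (batch, embed_dim, 1)
        x0 = tf.expand_dims(inputs, axis=2)
        xl = x0
        for w, b in zip(self.ws, self.bs):
            xl_T = tf.reshape(xl, [-1, 1, self.embed_dim])  # x_l^T: (batch, 1, embed_dim)
            xlTw = tf.tensordot(xl_T, w, axes=1)  # x_l^T w: (batch, 1, 1)
            xl = x0 * xlTw + b + xl  # x_{l+1} = x0 x_l^T w + b + x_l: (batch, embed_dim, 1)
        return tf.squeeze(xl, axis=2)  # (batch, embed_dim)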

# Model construction and training
model = Cross(256, 3)  # embed_dim=256, layer_num=3
model.compile(loss='binary_crossentropy', metrics=['accuracy', 'AUC'])

# .h5 suffix, and save_weights_only removed from the callback so the whole model is saved
checkpoint_save_path = 'model/deep_cross.h5'
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path)
model.fit(train_data, epochs=1, class_weight=class_weight, validation_data=dev_data, callbacks=[cp_callback])

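Why the missing names matter: in TF2 eager mode every tf.Variable created without a name gets the same default name, and tf.keras's HDF5 writer stores each weight under its variable name, so it ends up trying to create the same entry twice. A quick check, assuming TensorFlow 2.x, with small hypothetical sizes:

```python
import tensorflow as tf

embed_dim, layer_num = 4, 3  # small hypothetical sizes

# Without name=: all three variables fall back to one shared default name.
unnamed = [tf.Variable(tf.zeros((embed_dim, 1))) for _ in range(layer_num)]

# With name= passed to tf.Variable itself: every name is distinct.
named = [tf.Variable(tf.zeros((embed_dim, 1)), name="w" + str(i)) for i in range(layer_num)]

print([v.name for v in unnamed])
print([v.name for v in named])
```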
Bug 2: No model found in config file

Continuing from the code above, loading the file with tf.keras.models.load_model('model/deep_cross.h5') fails immediately with this error.
(1) Tried overriding get_config: no effect.
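One way to see where this message comes from: tf.keras raises "No model found in config file" when the .h5 file has no `model_config` attribute, which is the case for a file written by save_weights. A sketch assuming TensorFlow 2.x and h5py; it uses a small hypothetical Sequential model rather than Cross, because HDF5 full-model saving only accepts Sequential or Functional models.

```python
import os
import tempfile

import h5py
import tensorflow as tf

# Hypothetical small Sequential model (full-model HDF5 saving supports Sequential/Functional only).
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])

tmp = tempfile.mkdtemp()
weights_path = os.path.join(tmp, "weights_only.weights.h5")
full_path = os.path.join(tmp, "full_model.h5")

model.save_weights(weights_path)  # weights only
model.save(full_path)             # architecture + weights (HDF5 format)

# load_model looks for the model_config attribute at the root of the file.
with h5py.File(weights_path, "r") as f:
    weights_has_config = "model_config" in f.attrs
with h5py.File(full_path, "r") as f:
    full_has_config = "model_config" in f.attrs

print(weights_has_config, full_has_config)
```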

Bug 3: NotImplementedError: Saving the model to HDF5

Full error:
NotImplementedError: Saving the model to HDF5 format requires the model to be a Functional model or a Sequential model. It does not work for subclassed models, because such models are defined via the body of a Python method, which isn't safely serializable. Consider saving to the Tensorflow SavedModel format (by setting save_format="tf") or using save_weights

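This appears to share a root cause with bug 2: a custom (subclassed) network cannot be saved as h5. Saving the h5 file once after training, instead of during training, fails immediately as well.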

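model = Cross(256, 3)  # embed_dim=256, layer_num=3
model.compile(loss='binary_crossentropy', metrics=['accuracy', 'AUC'])

# .h5 suffix; no checkpoint callback this time, the model is saved once after training
checkpoint_save_path = 'model/deep_cross.h5'
model.fit(train_data, epochs=1, class_weight=class_weight, validation_data=dev_data)

# Raises the NotImplementedError above for a subclassed model
model.save(checkpoint_save_path)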
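The error message itself points at an alternative: the SavedModel format, which stores a traced TensorFlow graph plus variables instead of Python class code. In TF 2.x with Keras 2, model.save(path, save_format="tf") is the route the message suggests; the sketch below instead uses the lower-level tf.saved_model API, assuming TensorFlow 2.x. It uses a variant of Cross that creates its parameters with add_weight, a hypothetical `CrossServing` wrapper with a fixed input signature, and made-up sizes. The loading side calls tf.saved_model.load and does not use the Cross class.

```python
import tempfile

import numpy as np
import tensorflow as tf

EMBED_DIM, LAYER_NUM = 8, 2  # hypothetical sizes


class Cross(tf.keras.Model):
    """Cross network; parameters created with add_weight under unique names."""

    def __init__(self, embed_dim, layer_num):
        super().__init__()
        self.embed_dim = embed_dim
        self.ws = [self.add_weight(name="w" + str(i), shape=(embed_dim, 1),
                                   initializer=tf.keras.initializers.TruncatedNormal(stddev=0.01))
                   for i in range(layer_num)]
        self.bs = [self.add_weight(name="b" + str(i), shape=(embed_dim, 1), initializer="zeros")
                   for i in range(layer_num)]

    def call(self, inputs, training=None, mask=None):
        x0 = tf.expand_dims(inputs, axis=2)  # (batch, embed_dim, 1)
        xl = x0
        for w, b in zip(self.ws, self.bs):
            xl_T = tf.reshape(xl, [-1, 1, self.embed_dim])  # (batch, 1, embed_dim)
            xl = x0 * tf.tensordot(xl_T, w, axes=1) + b + xl
        return tf.squeeze(xl, axis=2)  # (batch, embed_dim)


class CrossServing(tf.Module):
    """Hypothetical wrapper exposing one traced prediction function."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    @tf.function(input_signature=[tf.TensorSpec([None, EMBED_DIM], tf.float32)])
    def predict(self, x):
        return self.model(x, training=False)


model = Cross(EMBED_DIM, LAYER_NUM)
x = np.random.default_rng(2).random((4, EMBED_DIM), dtype=np.float32)
expected = model(x).numpy()

export_dir = tempfile.mkdtemp()
tf.saved_model.save(CrossServing(model), export_dir)

# Loading side: the exported graph and variables are restored without the Cross class.
restored = tf.saved_model.load(export_dir)
result = restored.predict(tf.constant(x)).numpy()
print(np.allclose(expected, result))
```

The trade-off is that the restored object only offers the traced predict function and its variables, not a Keras model with compile and fit, which fits the goal of loading a model purely for prediction.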

In summary

For now I'll stick with save_weights and load_weights; the network is custom anyway, so I have its code.
