曹耘豪的博客

Keras 模型的保存与载入

  1. 保存
  2. 载入
  3. 保存自定义的层

Keras使用h5py进行模型的保存和载入

保存

Keras内部Model继承Container

调用Container.save方法进行保存模型

载入

使用如下代码载入模型

1
model = keras.models.load_model(model_path)

如果模型存在自定义的函数(有的激活函数Keras没有内置,如atan)和层Layer,则使用CustomObjectScope,具体代码如下:

1
2
3
4
5
6
7
8
9
10
11
from keras.utils import CustomObjectScope

with CustomObjectScope({
'atan': tf.atan,
'MyLayer': MyLayer,
}):
model = keras.models.load_model(model_path)

# 如果需要载入的模型进行预测
# 生成预测函数,这应该是Kears的一个Bug
model._make_predict_function()

保存自定义的层

一般自定义的层会继承Layer,如果自定义的层存在额外的参数,如下代码:

1
2
3
4
5
class MyLayer(layers.Layer):
def __init__(self, arg1, arg2=2, **kwargs)
super().__init__(**kwargs)
self.arg1 = arg1
self.arg2 = arg2

在存在额外参数的情况下,如果需要对自定义的层进行保存,则需要重写get_config方法,重写后的代码如下:

1
2
3
4
5
6
7
def get_config(self):
config = super().get_config()
config.update({
'arg1': self.arg1,
'arg2': self.arg2,
})
return config

这样在模型保存的时候便可以把自定义层的参数保存下来啦,也可以正常载入~

   /