Two ways to save and load Keras models

Method 1: Save the entire state

Save the model and the model diagram

# Save the model with model.save(file_path)
model_name = '{}/{}_{}_{}_v2.h5'.format(params['model_dir'], params['filters'], params['pool_size_1'], params['pool_size_2'])
model.save(model_name)

# Save a plot of the model architecture
# (requires pydot: pip install pydot, plus the Graphviz binaries)
from keras.utils import plot_model
model_plot = '{}/{}_{}_{}_v2.png'.format(params['model_dir'], params['filters'], params['pool_size_1'], params['pool_size_2'])
plot_model(model, to_file=model_plot)
  • The resulting model diagram is shown in the figure below.

The saved model diagram

Load the model

from keras.models import load_model

model_path = '../docs/keras/100_2_3_v2.h5'
model = load_model(model_path)
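
Because load_model restores the full compiled model (architecture, weights, and training configuration), the returned object can be used for inference right away. A minimal sketch, with hypothetical random data standing in for real test samples:

import numpy as np
from keras.models import load_model

model = load_model('../docs/keras/100_2_3_v2.h5')

# Hypothetical input: random values shaped like the model's expected
# input; replace with real test samples in practice.
x_test = np.random.random((10,) + model.input_shape[1:])
predictions = model.predict(x_test)
print(predictions.shape)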

Pros and cons

  • The first advantage is that saving and loading the model each take just one line of code, which is very convenient.
  • The second advantage is that it stores not only the model architecture and weights but also the training configuration (loss, optimizer, and optimizer state), so training can be resumed from wherever it was interrupted (see the sketch after this list).
  • The drawback is disk usage: saved this way, my model takes about 1 GB. That is fine for local use, but if the module has to be integrated by several people, uploading or syncing it becomes very time-consuming.
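
Because the .h5 file also stores the compile settings and optimizer state, training can be resumed without calling compile again. A minimal sketch, using random placeholder arrays where the real training data would go:

import numpy as np
from keras.models import load_model

# Reload the fully saved model: architecture, weights,
# compile configuration and optimizer state are all restored.
model = load_model('../docs/keras/100_2_3_v2.h5')

# Placeholder data, used only to keep the sketch self-contained;
# substitute the real training arrays from the interrupted run.
x_train = np.random.random((32,) + model.input_shape[1:])
y_train = np.random.random((32,) + model.output_shape[1:])

# Continue training where the previous run stopped; no compile() needed.
model.fit(x_train, y_train, batch_size=8, epochs=1)

# Overwrite the checkpoint with the updated state.
model.save('../docs/keras/100_2_3_v2.h5')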

Method 2: Save only the model architecture and weights

Save the model

Saving the model diagram works the same way as in Method 1.

# Save the model architecture to a YAML file (or, alternatively, a JSON file).
# The YAML variant requires PyYAML to be installed (pip install pyyaml).
yaml_string = model.to_yaml()
with open('../docs/keras/model_architecture.yaml', 'w') as f:
    f.write(yaml_string)
# json_string = model.to_json()
# with open('../docs/keras/model_architecture.json', 'w') as f:
#     f.write(json_string)

# Save the model weights to an HDF5 file
model.save_weights('../docs/keras/model_weights.h5')

Load the model

from keras.models import model_from_json
from keras.models import model_from_yaml

# Load the model architecture
with open('../docs/keras/model_architecture.yaml') as f:
    model = model_from_yaml(f.read())
# with open('../docs/keras/model_architecture.json') as f:
#     model = model_from_json(f.read())

# Load the model weights
model.load_weights('../docs/keras/model_weights.h5')

Pros and cons

  • The advantage is that it saves a lot of disk space, which makes syncing and collaboration much easier.
  • The drawback is that the training configuration (loss, optimizer, optimizer state) is lost, so the model has to be compiled again before further training or evaluation; see the sketch after this list.
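
Since the weights file contains no loss, optimizer, or metric settings, a model restored this way must be compiled again before fit() or evaluate() can be called. A minimal sketch; the compile arguments below are assumptions and should be replaced with whatever the original model used:

from keras.models import model_from_yaml

with open('../docs/keras/model_architecture.yaml') as f:
    model = model_from_yaml(f.read())
model.load_weights('../docs/keras/model_weights.h5')

# The training configuration was not saved, so compile again before
# training or evaluating. Loss/optimizer/metrics here are assumptions;
# use the same settings as in the original training run.
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

# The model is now ready for further use, e.g.:
# score = model.evaluate(x_test, y_test)
# model.fit(x_train, y_train, epochs=5)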

Choose whichever method fits your needs~
