TensorFlow Object detection API 教程系列:
- TensorFlow Object detection API 教程之一:Object detection API安装
- TensorFlow Object detection API 教程之二:训练自己的模型
- TensorFlow Object detection API 教程之三:测试自己的模型
在这一节,我们将要测试我们自己的模型,看一看训练的模型能否达到我们预期的效果。
将ckpt模型文件保存为pb模型文件
首先我们需要导出计算图(Inference Graph),在models/research/object_detection/
目录中,官方提供的export_inference_graph.py
脚本可以帮助我们轻松地去完成该操作。
找到一个想要导出pb文件的checkpoint,在models/research/object_detection/
路径下执行命令 :
1 | python3 export_inference_graph.py \ |
- input_type:保持模型,不用修改。
- pipeline_config_path:神经网络的参数设置文件路径,格式如上。
- trained_checkpoint_prefix:训练后最大步长的ckpt文件的目录,格式如上。
- output_directory:输入文件目录
如执行以上命令时报错为:no module named 'nets'
,进入models/research/
路径下执行:
1 | # From tensorflow/models/research/ |
读取pb模型文件
读取路径:
1
2
3
4
5
6ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
# Path to frozen detection graph. This is the actual model that is used for the object detection.
PATH_TO_CKPT = ROOT_PATH + '/include/hand_inference_graph/frozen_inference_graph.pb'
# List of the strings that is used to add correct label for each box.
PATH_TO_LABELS = ROOT_PATH + '/include/hand_inference_graph/hand_label_map.pbtxt'
NUM_CLASSES = 1加载模型:
1
2
3
4
5
6
7
8# Loading the model
detection_graph = tf.Graph()
with detection_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')加载标签:
1
2
3
4# Loading label map
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)读入图片:
1
2
3
4
5
6
7
8
9# For the sake of simplicity we will use only 2 images:
# image1.jpg
# image2.jpg
# If you want to test the code with your images, just add path to the images to the TEST_IMAGE_PATHS.
PATH_TO_TEST_IMAGES_DIR = 'test_images'
TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 3) ]
# Size, in inches, of the output images.
IMAGE_SIZE = (12, 8)
检测示例
完整代码
参考:
[1]https://pythonprogramming.net/testing-custom-object-detector-tensorflow-object-detection-api-tutorial/?completed=/training-custom-objects-tensorflow-object-detection-api-tutorial/
[2]https://github.com/tensorflow/models/blob/master/research/object_detection/object_detection_tutorial.ipynb