The value of knowledge lies not in possession, but in share.

0%

TensorFlow Object detection API教程之三:测试自己的模型

TensorFlow Object detection API 教程系列:

在这一节,我们将要测试我们自己的模型,看一看训练的模型能否达到我们预期的效果。

将ckpt模型文件保存为pb模型文件

首先我们需要导出计算图(Inference Graph),在models/research/object_detection/目录中,官方提供的export_inference_graph.py脚本可以帮助我们轻松地去完成该操作。

找到一个想要导出pb文件的checkpoint,在models/research/object_detection/路径下执行命令 :

1
2
3
4
5
python3 export_inference_graph.py \
--input_type image_tensor \
--pipeline_config_path training/ssd_mobilenet_v1_pets.config \
--trained_checkpoint_prefix training/model.ckpt-10856 \
--output_directory mac_n_cheese_inference_graph
  • input_type:保持模型,不用修改。
  • pipeline_config_path:神经网络的参数设置文件路径,格式如上。
  • trained_checkpoint_prefix:训练后最大步长的ckpt文件的目录,格式如上。
  • output_directory:输入文件目录

如执行以上命令时报错为:no module named 'nets',进入models/research/路径下执行:

1
2
# From tensorflow/models/research/
export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim

读取pb模型文件

  1. 读取路径:

    1
    2
    3
    4
    5
    6
    ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
    # Path to frozen detection graph. This is the actual model that is used for the object detection.
    PATH_TO_CKPT = ROOT_PATH + '/include/hand_inference_graph/frozen_inference_graph.pb'
    # List of the strings that is used to add correct label for each box.
    PATH_TO_LABELS = ROOT_PATH + '/include/hand_inference_graph/hand_label_map.pbtxt'
    NUM_CLASSES = 1
  2. 加载模型:

    1
    2
    3
    4
    5
    6
    7
    8
    # Loading the model
    detection_graph = tf.Graph()
    with detection_graph.as_default():
      od_graph_def = tf.GraphDef()
      with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')
  3. 加载标签:

    1
    2
    3
    4
    # Loading label map
    label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
    categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
    category_index = label_map_util.create_category_index(categories)
  4. 读入图片:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    # For the sake of simplicity we will use only 2 images:
    # image1.jpg
    # image2.jpg
    # If you want to test the code with your images, just add path to the images to the TEST_IMAGE_PATHS.
    PATH_TO_TEST_IMAGES_DIR = 'test_images'
    TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 3) ]

    # Size, in inches, of the output images.
    IMAGE_SIZE = (12, 8)

检测示例

完整代码

点击查看官方代码


参考:
[1]https://pythonprogramming.net/testing-custom-object-detector-tensorflow-object-detection-api-tutorial/?completed=/training-custom-objects-tensorflow-object-detection-api-tutorial/
[2]https://github.com/tensorflow/models/blob/master/research/object_detection/object_detection_tutorial.ipynb

🍭支持一根棒棒糖吧!