TensorFlow Object detection API 教程系列:
- TensorFlow Object detection API 教程之一:Object detection API安装
- TensorFlow Object detection API 教程之二:训练自己的模型
- TensorFlow Object detection API 教程之三:测试自己的模型
这一节,我们将对TensorFlow中的训练过程做一个介绍,训练模型的步骤可大体划分为以下几步:
收集数据
去收集至少100张包含你需要检测目标的图像,理想的情况是数据越多越好,不过对下一步的打标签带来沉重的任务。
将数据安装9:1的比例划分为训练集和测试集,并根据要训练的数据集,创建.pbtxt文件。
打标签
使用LabelImg对数据集打标签,可以生成Pascal VOC格式的xml文件。
关于LabelImg的相关教程请参考下方两个链接:
1.LabelImg介绍与安装教程 2.LabelImg使用教程
将数据转换为TF Records格式
借助Raccon_dataset中的xml_to_csv.py将数据由
XML
格式转为CSV
格式。1
2
3
4
5
6
7
8
9
10
11
12
13# 其中
def main():
image_path = os.path.join(os.getcwd(), 'annotations')
xml_df = xml_to_csv(image_path)
xml_df.to_csv('raccoon_labels.csv', index=None)
print('Successfully converted xml to csv.')
# 修改为:
def main():
for directory in ['train','test']:
image_path = os.path.join(os.getcwd(), 'images/{}'.format(directory))
xml_df = xml_to_csv(image_path)
xml_df.to_csv('data/{}_labels.csv'.format(directory), index=None)
print('Successfully converted xml to csv.')此时目录譬如下方结构:
1
2
3
4
5
6
7
8
9
10
11
12
13.
└── Object-Detection/
├── data/
│ └── test_labels.csv
| └── train_labels.csv
└── images/
| └── test/
| | └── testingimages.jpg
| └── train/
| └── trainingimages.jpg
| └── yourimages.jpg
└── training/
└── xml_to_csv.py借助Raccon_dataset中的generate_tfrecord.py将数据由
CSV
格式转为TF Records
格式。注意:
generate_tfrecord.py
的Todo部分需要与你的.pbtxt文件内的内容一致1
2
3
4
5
6
7# TO-DO replace this with label map
def class_text_to_int(row_label):
if row_label == 'macncheese':
return 1
else:
None
# 此处只有一类执行:
1
2
3# 譬如
python generate_tfrecord.py --csv_input=data/train_labels.csv --output_path=data/train.record
python3 generate_tfrecord.py --csv_input=data/test_labels.csv --output_path=data/test.record
另外在models/research/object_detection/dataset_tools
目录中,官方提供了一些数据转换工具。
配置模型参数
Tensorflow Object Detection API中模型参数、训练参数、评估参数都是在一个config文件中配置。
在配置模型参数的时候,通常有两种方式,一是使用预训练的模型,通过迁移学习(Transfer learning )来学习一个新目标(Object),这种训练方式可以大幅缩减训练的时间,使用少量的数据就可以达到较好的效果。另外一种是从头开始训练,end-to-end。
在models/research/object_detection/samples/configs/
的路径下,官方提供了一些object_detection配置文件的结构。在.config
中搜索所有的PATH_TO_BE_CONFIGURED
,修改为自己数据所存放的路径。另外还有heckpoint的路径、名称,num_classes的数目,label_map_path的路径等,按需修改。
训练
在tensorflow/models/research/
路径下,执行:
1 | # From tensorflow/models/research/ |
其中:--pipeline_config_path
,--model_dir
,--num_train_steps
等按需修改。
使用Tensorboard对过程进行监视
1 | tensorboard --logdir=${YOUR_DIRECTORY}/model_dir |
在浏览器中输入127.0.0.1:6006
观察训练的过程。
参考文献:
[1]https://pythonprogramming.net/custom-objects-tracking-tensorflow-object-detection-api-tutorial/
[2]https://pythonprogramming.net/creating-tfrecord-files-tensorflow-object-detection-api-tutorial/
[3]https://pythonprogramming.net/training-custom-objects-tensorflow-object-detection-api-tutorial/