使用TensorFlow目标检测框架训练检测自己的模型
大致流程可以分为:
采集数据集
- 类别数量
- 文件结构
标定数据集(按照标准生成xml文件,标定工具为(labelImg))
- 生成tfrecord文件
- 训练
- 导出冻结图(.pb)
- 根据冻结图推理
通常文件结构为:
- images(包含所有图像的目录)
- annotations
- xmls目录下(包含标定的xmls)文件
- trainval.txt(类别和类别的id)
- labels.pbtxt(label和id的映射文件) 采集数据集的代码:
# i:保存图像的路径
# n:保存文件的基础名字
# f:文件格式
# 使用方式python get_dataset.py -i imagepath -n test -f png
import argparse
import datetime
import imutils
import time
import cv2
ap = argparse.ArgumentParser()
ap.add_argument("-i","--imagepath",help="path to saved image")
ap.add_argument("-n","--imagename",help="image's name")
ap.add_argument("-f","--imageformat",help="format of image")
args = vars(ap.parse_args())
camera = cv2.VideoCapture(1)
time.sleep(0.25)
count = 0
image_saved_path = args.get('imagepath')
image_name = args.get('imagename')
image_format = args.get("imageformat")
print("保存路径:%s保存名称:%s保存格式%s"%(image_saved_path,image_name,image_format))
while True:
(grabbed,frame) = camera.read()
if not grabbed:
break
frame = imutils.resize(frame)
cv2.imshow('real sence',frame)
k = cv2.waitKey(1)
if k == 27:
break
elif k == ord("s"):
count = count + 1
issave = cv2.imwrite(image_saved_path+image_name+str(count)+'.'+image_format,frame)
if issave:
print("保存成功"+str(count)+"张图片")
else:
print("保存失败请检查")
else:
continue
camera.release()
cv2.destroyAllWindows()
生成tfrecord的代码: