99
1010import os
1111import glob
12- import time
13- import shutil
1412import cv2
1513import numpy as np
1614import torch
1715import torchvision .transforms as transforms
1816
1917from utils import file
2018from utils import util
21- from utils import draw
22- from models .location_dataset import LocationDataset
19+ from utils import voc_map
2320from models .yolo_v1 import YOLO_v1
2421
2522S = 7
2825
2926cate_list = ['cucumber' , 'eggplant' , 'mushroom' ]
3027
28+ dst_root_dir = './data/outputs'
29+ dst_target_dir = os .path .join (dst_root_dir , 'targets' )
30+ dst_pred_dir = os .path .join (dst_root_dir , 'preds' )
31+ dst_img_dir = os .path .join (dst_root_dir , 'imgs' )
32+ tmp_json_dir = os .path .join (dst_root_dir , '.tmp_files' )
33+
34+ file .check_dir (dst_root_dir )
35+ file .check_dir (dst_target_dir )
36+ file .check_dir (dst_pred_dir )
37+ file .check_dir (dst_img_dir )
38+ file .check_dir (tmp_json_dir )
39+
3140
3241def get_transform ():
3342 transform = transforms .Compose ([
@@ -41,9 +50,11 @@ def get_transform():
4150
4251
4352def load_data (root_dir ):
44- img_path_list = glob .glob (os .path .join (root_dir , '*.jpg' ))
45- annotation_path_list = [os .path .join (root_dir , os .path .splitext (os .path .basename (img_path ))[0 ] + ".xml" )
46- for img_path in img_path_list ]
53+ img_path_list = glob .glob (os .path .join (root_dir , 'imgs' , '*.jpg' ))
54+ img_path_list .sort ()
55+ annotation_path_list = [
56+ os .path .join (root_dir , 'annotations' , os .path .splitext (os .path .basename (img_path ))[0 ] + ".xml" )
57+ for img_path in img_path_list ]
4758
4859 return img_path_list , annotation_path_list
4960
@@ -137,16 +148,6 @@ def save_data(img_name, img, target_cates, target_bboxs, pred_cates, pred_probs,
137148 :param pred_probs: 预测边界框置信度
138149 :param pred_bboxs: 预测边界框坐标
139150 """
140- dst_root_dir = './data/outputs'
141- dst_target_dir = os .path .join (dst_root_dir , 'targets' )
142- dst_pred_dir = os .path .join (dst_root_dir , 'preds' )
143- dst_img_dir = os .path .join (dst_root_dir , 'imgs' )
144-
145- file .check_dir (dst_root_dir )
146- file .check_dir (dst_target_dir )
147- file .check_dir (dst_pred_dir )
148- file .check_dir (dst_img_dir )
149-
150151 img_path = os .path .join (dst_img_dir , img_name + ".png" )
151152 cv2 .imwrite (img_path , img )
152153 annotation_path = os .path .join (dst_target_dir , img_name + ".txt" )
@@ -176,7 +177,7 @@ def save_data(img_name, img, target_cates, target_bboxs, pred_cates, pred_probs,
176177 model = load_model (device )
177178
178179 transform = get_transform ()
179- img_path_list , annotation_path_list = load_data ('./data/training_images ' )
180+ img_path_list , annotation_path_list = load_data ('./data/location_dataset ' )
180181 # print(img_path_list)
181182
182183 N = len (img_path_list )
@@ -217,4 +218,5 @@ def save_data(img_name, img, target_cates, target_bboxs, pred_cates, pred_probs,
217218 img_name = os .path .splitext (os .path .basename (img_path ))[0 ]
218219 save_data (img_name , data_dict ['src' ], data_dict ['name_list' ], data_dict ['bndboxs' ],
219220 pred_cates , pred_cate_probs , pred_bboxs )
220- print ('done' )
221+ print ('compute mAP' )
222+ voc_map .voc_map (dst_target_dir , dst_pred_dir , tmp_json_dir )
0 commit comments