Commit af9fb0e

feat(map): VOC mAP calculation
1 parent 0aa8da3 commit af9fb0e

4 files changed: +348 −22 lines

py/batch_detect.py

Lines changed: 21 additions & 19 deletions
@@ -9,17 +9,14 @@
 
 import os
 import glob
-import time
-import shutil
 import cv2
 import numpy as np
 import torch
 import torchvision.transforms as transforms
 
 from utils import file
 from utils import util
-from utils import draw
-from models.location_dataset import LocationDataset
+from utils import voc_map
 from models.yolo_v1 import YOLO_v1
 
 S = 7
@@ -28,6 +25,18 @@
 
 cate_list = ['cucumber', 'eggplant', 'mushroom']
 
+dst_root_dir = './data/outputs'
+dst_target_dir = os.path.join(dst_root_dir, 'targets')
+dst_pred_dir = os.path.join(dst_root_dir, 'preds')
+dst_img_dir = os.path.join(dst_root_dir, 'imgs')
+tmp_json_dir = os.path.join(dst_root_dir, '.tmp_files')
+
+file.check_dir(dst_root_dir)
+file.check_dir(dst_target_dir)
+file.check_dir(dst_pred_dir)
+file.check_dir(dst_img_dir)
+file.check_dir(tmp_json_dir)
+
 
 def get_transform():
     transform = transforms.Compose([
@@ -41,9 +50,11 @@ def get_transform():
 
 
 def load_data(root_dir):
-    img_path_list = glob.glob(os.path.join(root_dir, '*.jpg'))
-    annotation_path_list = [os.path.join(root_dir, os.path.splitext(os.path.basename(img_path))[0] + ".xml")
-                            for img_path in img_path_list]
+    img_path_list = glob.glob(os.path.join(root_dir, 'imgs', '*.jpg'))
+    img_path_list.sort()
+    annotation_path_list = [
+        os.path.join(root_dir, 'annotations', os.path.splitext(os.path.basename(img_path))[0] + ".xml")
+        for img_path in img_path_list]
 
     return img_path_list, annotation_path_list
 
@@ -137,16 +148,6 @@ def save_data(img_name, img, target_cates, target_bboxs, pred_cates, pred_probs,
     :param pred_probs: confidence of each predicted bounding box
     :param pred_bboxs: coordinates of each predicted bounding box
     """
-    dst_root_dir = './data/outputs'
-    dst_target_dir = os.path.join(dst_root_dir, 'targets')
-    dst_pred_dir = os.path.join(dst_root_dir, 'preds')
-    dst_img_dir = os.path.join(dst_root_dir, 'imgs')
-
-    file.check_dir(dst_root_dir)
-    file.check_dir(dst_target_dir)
-    file.check_dir(dst_pred_dir)
-    file.check_dir(dst_img_dir)
-
     img_path = os.path.join(dst_img_dir, img_name + ".png")
     cv2.imwrite(img_path, img)
     annotation_path = os.path.join(dst_target_dir, img_name + ".txt")
@@ -176,7 +177,7 @@ def save_data(img_name, img, target_cates, target_bboxs, pred_cates, pred_probs,
     model = load_model(device)
 
     transform = get_transform()
-    img_path_list, annotation_path_list = load_data('./data/training_images')
+    img_path_list, annotation_path_list = load_data('./data/location_dataset')
     # print(img_path_list)
 
     N = len(img_path_list)
@@ -217,4 +218,5 @@ def save_data(img_name, img, target_cates, target_bboxs, pred_cates, pred_probs,
         img_name = os.path.splitext(os.path.basename(img_path))[0]
         save_data(img_name, data_dict['src'], data_dict['name_list'], data_dict['bndboxs'],
                   pred_cates, pred_cate_probs, pred_bboxs)
-    print('done')
+    print('compute mAP')
+    voc_map.voc_map(dst_target_dir, dst_pred_dir, tmp_json_dir)

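Note on the new call to voc_map.voc_map(dst_target_dir, dst_pred_dir, tmp_json_dir): the module it lives in (presumably py/lib/utils/voc_map.py, the fourth file of this commit) is not shown in this diff. As a rough guide only, here is a minimal sketch of the standard PASCAL VOC mAP computation that such a module usually implements; every function name and detail below is an assumption, not the commit's actual code.

# Illustrative sketch only: this is NOT the implementation of py/lib/utils/voc_map.py,
# which is not visible in this diff. It shows the standard PASCAL VOC procedure such a
# module typically performs: match confidence-sorted detections against ground truth at
# IoU >= 0.5, build a per-class precision/recall curve, take the all-point interpolated
# area under it, and average the per-class APs into mAP.
import numpy as np


def iou(box_a, box_b):
    # boxes are [xmin, ymin, xmax, ymax]
    ix1, iy1 = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
    ix2, iy2 = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    return inter / float(area_a + area_b - inter + 1e-10)


def average_precision(detections, gt_by_file, num_gt, iou_thresh=0.5):
    # detections: per-class list of {'file_id', 'confidence', 'bbox'}, sorted by confidence (desc)
    # gt_by_file: file_id -> list of {'bbox', 'used'} for the same class
    # num_gt: total ground-truth boxes of this class (denominator of recall)
    tp = np.zeros(len(detections))
    fp = np.zeros(len(detections))
    for i, det in enumerate(detections):
        best_iou, best_gt = 0.0, None
        for gt in gt_by_file.get(det['file_id'], []):
            overlap = iou(det['bbox'], gt['bbox'])
            if overlap > best_iou:
                best_iou, best_gt = overlap, gt
        if best_gt is not None and best_iou >= iou_thresh and not best_gt['used']:
            best_gt['used'] = True   # each ground-truth box may be matched only once
            tp[i] = 1
        else:
            fp[i] = 1
    tp, fp = np.cumsum(tp), np.cumsum(fp)
    recall = tp / max(num_gt, 1)
    precision = tp / np.maximum(tp + fp, 1e-10)
    # all-point interpolation: make precision monotonically non-increasing, then integrate
    mrec = np.concatenate(([0.0], recall, [1.0]))
    mpre = np.concatenate(([0.0], precision, [0.0]))
    for i in range(len(mpre) - 2, -1, -1):
        mpre[i] = max(mpre[i], mpre[i + 1])
    idx = np.where(mrec[1:] != mrec[:-1])[0]
    return float(np.sum((mrec[idx + 1] - mrec[idx]) * mpre[idx + 1]))

# mAP would then be the mean of average_precision(...) over the classes in cate_list.
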
py/lib/utils/draw.py

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ def plot_bboxs(img, bndboxs, name_list, pred_boxs, pred_cates, pred_probs):
 
         xmin, ymin, xmax, ymax = np.array(bbox, dtype=np.int)
         cv2.rectangle(dst, (xmin, ymin), (xmax, ymax), (0, 0, 255), thickness=1)
-        cv2.putText(dst, '%s_%f' % (cate_list[cate], prob), (xmin, ymax), 1, cv2.FONT_HERSHEY_PLAIN, (0, 0, 255),
+        cv2.putText(dst, '%s_%.3f' % (cate_list[cate], prob), (xmin, ymin), 1, cv2.FONT_HERSHEY_PLAIN, (0, 0, 255),
                     thickness=1)
 
     return dst

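A side note on the putText call touched above, purely as an OpenCV API observation: the call passes 1 where fontFace goes and cv2.FONT_HERSHEY_PLAIN where fontScale goes. The two values happen to coincide (cv2.FONT_HERSHEY_PLAIN == 1), so rendering is unaffected; the conventional argument order is shown in this small standalone snippet (image and label text are illustrative):

import cv2
import numpy as np

# cv2.putText(img, text, org, fontFace, fontScale, color, thickness)
canvas = np.zeros((240, 320, 3), dtype=np.uint8)      # dummy image just for illustration
cv2.putText(canvas, 'cucumber_0.973', (10, 30),
            cv2.FONT_HERSHEY_PLAIN, 1, (0, 0, 255), thickness=1)
cv2.imwrite('label_demo.png', canvas)
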
py/lib/utils/file.py

Lines changed: 79 additions & 2 deletions
@@ -11,11 +11,15 @@
 import xmltodict
 import numpy as np
 import torch
+import shutil
+import json
+import glob
 
 
 def check_dir(data_dir):
-    if not os.path.exists(data_dir):
-        os.mkdir(data_dir)
+    if os.path.exists(data_dir):
+        shutil.rmtree(data_dir)
+    os.mkdir(data_dir)
 
 
 def parse_location_xml(xml_path):
@@ -72,6 +76,79 @@ def parse_location_xml(xml_path):
     return np.array(bndboxs), name_list
 
 
+def file_lines_to_list(path):
+    """
+    Convert the lines of a file to a list
+    """
+    # open txt file lines to a list
+    with open(path) as f:
+        content = f.readlines()
+    # remove whitespace characters like `\n` at the end of each line
+    content = [x.strip() for x in content]
+    return content
+
+
+def parse_ground_truth(ground_truth_dir, tmp_json_dir):
+    """
+    Parse the ground-truth bounding boxes of every image and save them in the format {"cate": "cucumber", "bbox": [23, 42, 206, 199], "used": true}
+    """
+    gt_path_list = glob.glob(os.path.join(ground_truth_dir, '*.txt'))
+
+    # count the number of ground-truth boxes per class
+    gt_per_classes_dict = {}
+    for gt_path in gt_path_list:
+        json_list = list()
+        lines = file_lines_to_list(gt_path)
+        for line in lines:
+            cate, xmin, ymin, xmax, ymax = line.split(' ')
+            json_list.append({'cate': cate, 'bbox': [int(xmin), int(ymin), int(xmax), int(ymax)], 'used': False})
+
+            if gt_per_classes_dict.get(cate) is None:
+                gt_per_classes_dict[cate] = 1
+            else:
+                gt_per_classes_dict[cate] += 1
+        # save
+        name = os.path.splitext(os.path.basename(gt_path))[0]
+        json_path = os.path.join(tmp_json_dir, name + ".json")
+        with open(json_path, 'w') as f:
+            json.dump(json_list, f)
+
+    return gt_per_classes_dict
+
+
+def parse_detection_results(detection_result_dir, tmp_json_dir):
+    """
+    Parse the predicted bounding boxes of every class and save them in the format {"confidence": "0.999", "file_id": "cucumber_61", "bbox": [16, 42, 225, 163]}
+    """
+    dr_path_list = glob.glob(os.path.join(detection_result_dir, '*.txt'))
+
+    # predicted bounding boxes grouped per class
+    dt_per_classes_dict = dict()
+    for dr_path in dr_path_list:
+        lines = file_lines_to_list(dr_path)
+        name = os.path.splitext(os.path.basename(dr_path))[0]
+
+        for line in lines:
+            cate, confidence, xmin, ymin, xmax, ymax = line.split(' ')
+            if dt_per_classes_dict.get(cate) is None:
+                dt_per_classes_dict[cate] = [
+                    {'confidence': confidence, 'file_id': name, 'bbox': [int(xmin), int(ymin), int(xmax), int(ymax)]}]
+            else:
+                dt_per_classes_dict[cate].append(
+                    {'confidence': confidence, 'file_id': name, 'bbox': [int(xmin), int(ymin), int(xmax), int(ymax)]})
+
+    # save
+    for key, value in dt_per_classes_dict.items():
+        # sort by confidence in descending order
+        value.sort(key=lambda x: float(x['confidence']), reverse=True)
+
+        json_path = os.path.join(tmp_json_dir, key + "_dt.json")
+        with open(json_path, 'w') as f:
+            json.dump(value, f)
+
+    return dt_per_classes_dict
+
+
 def save_model(model_weights, model_save_path):
     torch.save(model_weights, model_save_path)
 

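From the parsing above, each ground-truth file in the targets directory is expected to hold one "cate xmin ymin xmax ymax" line per object, and each prediction file one "cate confidence xmin ymin xmax ymax" line, matching what save_data in batch_detect.py writes. A minimal usage sketch of the two new helpers follows; the example file lines and printed values are illustrative assumptions, and the paths mirror the directories set up in batch_detect.py.

# Assumes batch_detect.py has already populated ./data/outputs/{targets,preds,.tmp_files};
# the example file contents below are hypothetical.
from utils import file

# ./data/outputs/targets/cucumber_61.txt  ->  lines like "cucumber 23 42 206 199"
# ./data/outputs/preds/cucumber_61.txt    ->  lines like "cucumber 0.999 16 42 225 163"
gt_per_class = file.parse_ground_truth('./data/outputs/targets', './data/outputs/.tmp_files')
dt_per_class = file.parse_detection_results('./data/outputs/preds', './data/outputs/.tmp_files')

print(gt_per_class)         # e.g. {'cucumber': 38, 'eggplant': 35, 'mushroom': 41}
print(dt_per_class.keys())  # detections grouped per class, each list sorted by confidence
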