 
 """
 @date: 2020/3/2 8:07 AM
-@file: detector.py
+@file: car_detector.py
 @author: zj
 @description: vehicle category detector
 """
 
-import os
 import copy
 import cv2
 import torch
 import torch.nn as nn
 from torchvision.models import alexnet
 import torchvision.transforms as transforms
 import selectivesearch
+from utils.util import parse_xml

-from utils.util import parse_car_csv

-if __name__ == '__main__':
-    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+def get_device():
+    return torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+
+
+def get_transform():
     # data transforms
     transform = transforms.Compose([
         transforms.ToPILImage(),
         transforms.Resize((227, 227)),
+        transforms.RandomHorizontalFlip(),
         transforms.ToTensor(),
-        transforms.Normalize((0.5,), (0.5,))
+        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
     ])
+    return transform
+
+
+def get_model(device=None):
     # load the CNN model
     model = alexnet()
     num_classes = 2
     num_features = model.classifier[6].in_features
     model.classifier[6] = nn.Linear(num_features, num_classes)
-    model.load_state_dict(torch.load('./models/linear_svm_alexnet_car.pth'))
+    model.load_state_dict(torch.load('./models/linear_svm_alexnet_car_4.pth'))
     model.eval()
-    # print(model)
-    model = model.to(device)
+
     # disable gradient tracking
     for param in model.parameters():
         param.requires_grad = False
+    if device:
+        model = model.to(device)
+
+    return model
+
+
+if __name__ == '__main__':
+    device = get_device()
+    transform = get_transform()
+    model = get_model(device=device)
+
     # create the selective search object
     gs = selectivesearch.get_selective_search()

-    car_root_dir = './data/voc_car/'
-    val_root_dir = os.path.join(car_root_dir, 'val')
-    samples = parse_car_csv(val_root_dir)
-
-    for sample_name in samples:
-        jpeg_path = os.path.join(val_root_dir, 'JPEGImages', sample_name + ".jpg")
-        annotation_path = os.path.join(val_root_dir, 'Annotations', sample_name + ".xml")
+    test_img_path = './data/voc_car/val/JPEGImages/000007.jpg'
+    test_xml_path = './data/voc_car/val/Annotations/000007.xml'

-        img = cv2.imread(jpeg_path)
-        dst = copy.deepcopy(img)
+    img = cv2.imread(test_img_path)
+    dst = copy.deepcopy(img)

-        # candidate region proposals
-        selectivesearch.config(gs, img, strategy='f')
-        rects = selectivesearch.get_rects(gs)
-        print('number of candidate region proposals: %d' % len(rects))
+    bndboxs = parse_xml(test_xml_path)
+    for bndbox in bndboxs:
+        xmin, ymin, xmax, ymax = bndbox
+        cv2.rectangle(dst, (xmin, ymin), (xmax, ymax), color=(0, 255, 0), thickness=2)

-        rects_transform = transform(rects)
-        print(rects_transform.shape)
-        exit(0)
+    # candidate region proposals
+    selectivesearch.config(gs, img, strategy='f')
+    rects = selectivesearch.get_rects(gs)
+    print('number of candidate region proposals: %d' % len(rects))

-        for rect in rects:
-            xmin, ymin, xmax, ymax = rect
-            rect_img = img[ymin:ymax, xmin:xmax]
+    for rect in rects:
+        xmin, ymin, xmax, ymax = rect
+        rect_img = img[ymin:ymax, xmin:xmax]

-            rect_transform = transform(rect_img).to(device)
-            output = model(rect_transform.unsqueeze(0))[0]
+        rect_transform = transform(rect_img).to(device)
+        output = model(rect_transform.unsqueeze(0))[0]

-            if torch.argmax(output).item() == 1:
-                """
-                predicted as a car
-                """
-                cv2.rectangle(dst, (xmin, ymin), (xmax, ymax), color=(0, 0, 255), thickness=1)
-                print(rect, output)
+        if torch.argmax(output).item() == 1:
+            """
+            predicted as a car
+            """
+            cv2.rectangle(dst, (xmin, ymin), (xmax, ymax), color=(0, 0, 255), thickness=2)
+            print(rect, output)

-        cv2.imshow('img', dst)
-        cv2.waitKey(0)
+    cv2.imshow('img', dst)
+    cv2.waitKey(0)
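
With the script factored into get_device(), get_transform(), and get_model(), the detector can also be driven from another script instead of only via __main__. Below is a minimal usage sketch, not part of the commit: the crop coordinates are arbitrary, and it assumes car_detector.py and the weights at ./models/linear_svm_alexnet_car_4.pth are reachable from the working directory.

import cv2
import torch
from car_detector import get_device, get_transform, get_model

device = get_device()
transform = get_transform()
model = get_model(device=device)

# Classify a single (hypothetical) region crop instead of running selective search.
img = cv2.imread('./data/voc_car/val/JPEGImages/000007.jpg')
crop = img[100:300, 150:400]  # arbitrary ymin:ymax, xmin:xmax, for illustration only
x = transform(crop).unsqueeze(0).to(device)
with torch.no_grad():
    scores = model(x)[0]
print('car' if torch.argmax(scores).item() == 1 else 'background', scores)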