Skip to content
This repository was archived by the owner on Nov 2, 2024. It is now read-only.

Commit becee1e

Browse files
committed
perf(svm): 判断负样本是否重复添加
1 parent 738882d commit becee1e

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

py/linear_svm.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,12 +93,15 @@ def hinge_loss(outputs, labels):
9393
return loss
9494

9595

96-
def add_hard_negatives(hard_negative_list, negative_list, add_negative_list):
    """Append hard negatives to ``negative_list``, skipping repeated rects.

    Both ``negative_list`` and ``add_negative_list`` are mutated in place;
    nothing is returned.

    Args:
        hard_negative_list: Iterable of sample dicts. Each dict must have a
            ``'rect'`` entry that can be converted with ``list(...)``.
        negative_list: List of negative samples that newly seen hard
            negatives are appended to.
        add_negative_list: List of rects (each stored as a plain ``list``)
            that were already added by earlier calls. May be ``None``, in
            which case duplicates are only filtered within this call and the
            history is not visible to the caller afterwards.
    """
    if add_negative_list is None:
        # No history supplied: track rects locally so this call still
        # de-duplicates its own input instead of failing on None.append().
        add_negative_list = []

    for item in hard_negative_list:
        # Normalise to a plain list so the membership test compares like with
        # like; the history stores list(rect), so a tuple or array-like rect
        # compared directly would never match (or could raise).
        rect = list(item['rect'])
        if rect not in add_negative_list:
            negative_list.append(item)
            add_negative_list.append(rect)
102105

103106

104107
def get_hard_negatives(preds, cache_dicts):
@@ -195,6 +198,8 @@ def train_model(data_loaders, model, criterion, optimizer, lr_scheduler, num_epo
195198

196199
# 获取训练数据集的负样本集
197200
negative_list = train_dataset.get_negatives()
201+
# 记录后续增加的负样本
202+
add_negative_list = data_loaders['add_negative']
198203

199204
running_corrects = 0
200205
# Iterate over data.
@@ -212,7 +217,7 @@ def train_model(data_loaders, model, criterion, optimizer, lr_scheduler, num_epo
212217
running_corrects += torch.sum(preds == labels.data)
213218

214219
hard_negative_list, easy_neagtive_list = get_hard_negatives(preds.cpu().numpy(), cache_dicts)
215-
negative_list = add_hard_negatives(hard_negative_list, negative_list)
220+
add_hard_negatives(hard_negative_list, negative_list, add_negative_list)
216221

217222
remain_acc = running_corrects.double() / data_sizes[phase]
218223
print('remiam negative size: {}, acc: {:.4f}'.format(len(remain_negative_list), remain_acc))
@@ -223,6 +228,7 @@ def train_model(data_loaders, model, criterion, optimizer, lr_scheduler, num_epo
223228
batch_positive, batch_negative)
224229
data_loaders['train'] = DataLoader(train_dataset, batch_size=batch_total, sampler=tmp_sampler,
225230
num_workers=8, drop_last=True)
231+
data_loaders['add_negative'] = add_negative_list
226232
# 重置数据集大小
227233
data_sizes['train'] = len(tmp_sampler)
228234

0 commit comments

Comments
 (0)