You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

103 lines
3.2 KiB

import numpy as np
import torch
import time,pdb
def converter(data):
    """Return ``data`` as a flat (1-D) NumPy array.

    Args:
        data: A ``torch.Tensor``, a NumPy array, or any array-like object
            accepted by ``np.asarray`` (list, tuple, scalar, ...).

    Returns:
        A 1-D ``np.ndarray`` copy of the input values.
    """
    if isinstance(data, torch.Tensor):
        # detach() is the supported replacement for the legacy `.data`
        # attribute; it also lets tensors that require grad be converted.
        data = data.detach().cpu().numpy()
    # np.asarray lets plain Python sequences through; flatten() copies, so the
    # result never aliases the caller's tensor/array storage.
    return np.asarray(data).flatten()
def fast_hist(label_pred, label_true, num_classes):
    """Build a confusion matrix from paired predicted / true labels.

    Args:
        label_pred: Array-like of predicted class indices.
        label_true: Array-like of ground-truth class indices, same length as
            ``label_pred``.
        num_classes: Number of classes; valid labels are ``0..num_classes-1``.

    Returns:
        Integer array of shape ``(num_classes, num_classes)`` where entry
        ``[t, p]`` counts samples with true label ``t`` and prediction ``p``.
        Pairs in which either label lies outside ``[0, num_classes)`` (for
        example an "ignore" label) are skipped.
    """
    # Both inputs must be integral for np.bincount; float labels are common
    # when the data originates from float arrays/tensors.
    label_pred = np.asarray(label_pred).astype(int)
    label_true = np.asarray(label_true).astype(int)

    # Drop out-of-range pairs: otherwise the flat index either exceeds
    # num_classes ** 2 (reshape fails) or lands in a different row and is
    # silently counted as another class.
    in_range = (
        (label_true >= 0) & (label_true < num_classes)
        & (label_pred >= 0) & (label_pred < num_classes)
    )
    flat_index = num_classes * label_true[in_range] + label_pred[in_range]
    hist = np.bincount(flat_index, minlength=num_classes ** 2)
    return hist.reshape(num_classes, num_classes)
class Metric_mIoU():
    """Running mean-IoU / mean class accuracy over an accumulated confusion matrix.

    ``self.hist`` is a ``(class_num, class_num)`` float array that ``update``
    grows by adding the result of ``fast_hist``; rows correspond to the true
    label and columns to the predicted label.
    """

    def __init__(self, class_num):
        """Create an empty metric for ``class_num`` classes."""
        self.class_num = class_num
        self.reset()

    def update(self, predict, target):
        """Add one batch of predictions and targets to the confusion matrix."""
        predict, target = converter(predict), converter(target)
        self.hist += fast_hist(predict, target, self.class_num)

    def reset(self):
        """Clear the accumulated confusion matrix."""
        self.hist = np.zeros((self.class_num, self.class_num))

    @staticmethod
    def _mean_ratio(numerator, denominator):
        """Average ``numerator / denominator`` over entries with a positive denominator.

        Classes that never occurred have a zero denominator; they are left out
        of the mean (what ``np.nanmean`` did for the resulting 0/0 entries)
        without performing the division, so no RuntimeWarning is emitted.
        Returns NaN when no class has been observed at all.
        """
        observed = denominator > 0
        if not observed.any():
            return np.float64(np.nan)
        return np.mean(numerator[observed] / denominator[observed])

    def get_miou(self):
        """Return the mean intersection-over-union across observed classes."""
        intersection = np.diag(self.hist)
        # union = (#true == c) + (#pred == c) - (#both == c)
        union = (
            np.sum(self.hist, axis=1) + np.sum(self.hist, axis=0) - intersection
        )
        return self._mean_ratio(intersection, union)

    def get_acc(self):
        """Return the mean per-class accuracy (diagonal over row sums)."""
        return self._mean_ratio(np.diag(self.hist), self.hist.sum(axis=1))

    def get(self):
        """Return the headline value of this metric (mean IoU)."""
        return self.get_miou()
class MultiLabelAcc():
    """Running element-wise accuracy: fraction of predictions equal to the target."""

    def __init__(self):
        """Create a metric with no samples accumulated."""
        self.reset()

    def reset(self):
        """Discard all accumulated counts."""
        self.cnt = 0  # number of elements seen so far
        self.correct = 0  # number of elements where predict == target

    def update(self, predict, target):
        """Accumulate one batch; both inputs are flattened before comparison."""
        predict, target = converter(predict), converter(target)
        self.cnt += len(predict)
        self.correct += np.sum(predict == target)

    def get_acc(self):
        """Return ``correct / cnt``, or NaN when nothing has been accumulated."""
        if self.cnt == 0:
            # Avoid ZeroDivisionError when queried before the first update or
            # right after reset(); NaN marks the value as undefined.
            return float("nan")
        return self.correct * 1.0 / self.cnt

    def get(self):
        """Return the headline value of this metric (accuracy)."""
        return self.get_acc()
class AccTopk():
    """Running tolerance-based accuracy with special handling of a background label.

    An element counts as correct when

    * either the prediction or the target equals ``background_classes`` and
      the two are exactly equal, or
    * neither is background and ``abs(predict - target) < k``.

    The counter attribute keeps its historical name ``top5_correct`` even
    though the tolerance is the configurable ``k``.
    """

    def __init__(self, background_classes, k):
        """Store the background label and the tolerance ``k``; start empty."""
        self.background_classes = background_classes
        self.k = k
        self.reset()

    def reset(self):
        """Discard all accumulated counts."""
        self.cnt = 0  # number of elements seen so far
        self.top5_correct = 0  # number of elements judged correct

    @staticmethod
    def _as_signed(labels):
        """Promote integer labels to int64 so a difference cannot wrap around."""
        if np.issubdtype(labels.dtype, np.integer):
            return labels.astype(np.int64)
        return labels

    def _count_correct(self, predict, target):
        """Count correct elements in two flat, equally sized label arrays."""
        is_background = (
            (predict == self.background_classes)
            | (target == self.background_classes)
        )
        # Background positions must match exactly.
        exact_hits = np.sum(predict[is_background] == target[is_background])

        foreground = np.logical_not(is_background)
        # Unsigned dtypes (e.g. uint8) would wrap on subtraction (1 - 2 == 255)
        # and turn near hits into misses, so subtract in a signed type.
        gap = np.absolute(
            self._as_signed(predict[foreground])
            - self._as_signed(target[foreground])
        )
        return exact_hits + np.sum(gap < self.k)

    def update(self, predict, target):
        """Accumulate one batch; both inputs are flattened before comparison."""
        predict, target = converter(predict), converter(target)
        self.cnt += len(predict)
        self.top5_correct += self._count_correct(predict, target)

    def get(self):
        """Return ``top5_correct / cnt``, or NaN when nothing has been accumulated."""
        if self.cnt == 0:
            # Avoid ZeroDivisionError when queried before the first update or
            # right after reset().
            return float("nan")
        return self.top5_correct * 1.0 / self.cnt
def update_metrics(metric_dict, pair_data):
    """Feed every configured metric with its two inputs from ``pair_data``.

    ``metric_dict`` holds parallel sequences under the keys ``'name'``,
    ``'op'`` and ``'data_src'``. For each position, the first two items of the
    ``'data_src'`` entry are used as keys into ``pair_data``, and the looked-up
    values are passed, in that order, to the ``update`` method of the matching
    ``'op'`` entry.
    """
    for position, _ in enumerate(metric_dict['name']):
        metric = metric_dict['op'][position]
        source_keys = metric_dict['data_src'][position]
        first_input = pair_data[source_keys[0]]
        second_input = pair_data[source_keys[1]]
        metric.update(first_input, second_input)
def reset_metrics(metric_dict):
    """Call ``reset()`` on every metric object listed under ``metric_dict['op']``."""
    metrics = metric_dict['op']
    for metric in metrics:
        metric.reset()
if __name__ == '__main__':
    # Manual smoke test, only executed when this file is run as a script.
    #
    # Disabled mIoU check, kept for reference:
    #     p = np.random.randint(5, size=(800, 800))
    #     t = np.zeros((800, 800))
    #     me = Metric_mIoU(5)
    #     me.update(p, p); me.update(p, t)
    #     me.update(p, p); me.update(p, t)
    #     print(me.get_miou())
    #     print(me.get_acc())

    # AccTopk with background label 0 and tolerance 5: nine of the ten pairs
    # below are within tolerance, so the printed score is 0.9.
    targets = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 0])
    predictions = np.array([1, 1, 2, 2, 2, 3, 3, 4, 4, 0])
    topk_metric = AccTopk(0, 5)
    topk_metric.update(predictions, targets)
    print(topk_metric.get())