diff --git a/draw/draw_confidence_histogram.py b/draw/draw_confidence_histogram.py index bfb24ee..2c18f8f 100644 --- a/draw/draw_confidence_histogram.py +++ b/draw/draw_confidence_histogram.py @@ -7,48 +7,50 @@ from pyecharts.globals import ThemeType if __name__ == '__main__': outcome_dir = r'E:\Data\Research\Outcome' - configs_dir = r'\Magellan+Smac+roberta-large-nli-stsb-mean-tokens+inter-0.5' + inter_list = ['0', '0.5', '0.7', '0.9', '1'] + configs_dir = r'\Magellan+Smac+roberta-large-nli-stsb-mean-tokens+inter-' datasets_list = os.listdir(outcome_dir) for _ in datasets_list: - path = outcome_dir + rf'\{_}' + configs_dir - statistics_files = os.listdir(path) - length = 0 - for file in statistics_files: - if file.startswith('predictions'): - preds = pd.read_csv(path + rf'\{file}', encoding='ISO-8859-1') - preds = preds[['predicted', 'confidence']] - preds = preds.astype(float) - preds = preds[preds['predicted'] == 1.0] - length = len(preds) - li = [] - zeros = len(preds[preds['confidence'] == 0]) - dot_02 = len(preds[(preds['confidence'] > 0) & (preds['confidence'] <= 0.2)]) - dot_24 = len(preds[(preds['confidence'] > 0.2) & (preds['confidence'] <= 0.4)]) - dot_46 = len(preds[(preds['confidence'] > 0.4) & (preds['confidence'] <= 0.6)]) - dot_68 = len(preds[(preds['confidence'] > 0.6) & (preds['confidence'] <= 0.8)]) - dot_80 = len(preds[(preds['confidence'] > 0.8) & (preds['confidence'] <= 1.0)]) - for number in [zeros, dot_02, dot_24, dot_46, dot_68, dot_80]: - li.append(round(number * 100 / length, ndigits=3)) + for inter in inter_list: + path = outcome_dir + rf'\{_}' + configs_dir + inter + statistics_files = os.listdir(path) + length = 0 + for file in statistics_files: + if file.startswith('predictions'): + preds = pd.read_csv(path + rf'\{file}', encoding='ISO-8859-1') + preds = preds[['predicted', 'confidence']] + preds = preds.astype(float) + preds = preds[preds['predicted'] == 1.0] + length = len(preds) + li = [] + zeros = len(preds[preds['confidence'] == 0]) + dot_02 = len(preds[(preds['confidence'] > 0) & (preds['confidence'] <= 0.2)]) + dot_24 = len(preds[(preds['confidence'] > 0.2) & (preds['confidence'] <= 0.4)]) + dot_46 = len(preds[(preds['confidence'] > 0.4) & (preds['confidence'] <= 0.6)]) + dot_68 = len(preds[(preds['confidence'] > 0.6) & (preds['confidence'] <= 0.8)]) + dot_80 = len(preds[(preds['confidence'] > 0.8) & (preds['confidence'] <= 1.0)]) + for number in [zeros, dot_02, dot_24, dot_46, dot_68, dot_80]: + li.append(round(number * 100 / length, ndigits=3)) - c = ( - Bar(init_opts=opts.InitOpts(theme=ThemeType.WALDEN)) - .add_xaxis(['conf=0', '0