"""Merge the matching pairs of the DeepMatcher Walmart-Amazon (Dirty) splits into one CSV.

Reads train.csv, valid.csv and test.csv, keeps only the rows whose ``label`` is 1,
drops the ``label`` column, sorts by ``ltable_id`` and writes ``matches.csv``.
"""
import csv
import json
from pathlib import Path

import pandas as pd
import sentence_transformers.util
import torch
from sentence_transformers import SentenceTransformer
from torch import nn


# Directory holding the DeepMatcher "Dirty/Walmart-Amazon" split files.
DATASET_DIR = Path(r'E:\Data\Research\Datasets\DeepMatcher dataset\Dirty\Walmart-Amazon')

# Where the merged list of matching pairs is written.
OUTPUT_PATH = Path(r'E:\Data\Research\Projects\matching_dependency\datasets\Walmart-Amazon_dirty\matches.csv')

# The split CSVs are read as ISO-8859-1 instead of pandas' UTF-8 default.
CSV_ENCODING = 'ISO-8859-1'

# Split files to merge, in this order. Because the sort in collect_matches is
# stable, this order also decides how rows with equal ltable_id are arranged.
SPLITS = ('train', 'valid', 'test')

# NOTE(review): csv, json, torch, nn and the sentence_transformers imports at the
# top of the file are not used by this script; they were only referenced by a
# since-removed, commented-out SentenceTransformer cosine-similarity experiment.


def collect_matches(frames, label_column='label', sort_column='ltable_id'):
    """Return the positive pairs of all *frames* as one sorted DataFrame.

    Args:
        frames: Iterable of DataFrames that each contain *label_column* and
            *sort_column*. It must yield at least one frame, because
            ``pd.concat`` raises ``ValueError`` on an empty input.
        label_column: Name of the match-label column. Rows equal to 1 are kept,
            and the column itself is dropped from the result.
        sort_column: Column the result is sorted by.

    Returns:
        A new DataFrame with a fresh 0..n-1 index. Rows sharing the same
        *sort_column* value keep their input order (frame order, then row
        order). The input frames are not modified.
    """
    positives = [frame[frame[label_column] == 1] for frame in frames]
    # The per-split indices overlap and carry no meaning once merged.
    merged = pd.concat(positives, ignore_index=True).drop(columns=[label_column])
    # The default quicksort is not stable. A stable sort makes the order of
    # tied rows reproducible.
    return merged.sort_values(by=sort_column, kind='stable', ignore_index=True)


def main(dataset_dir=DATASET_DIR, output_path=OUTPUT_PATH):
    """Merge the matching pairs of every split and write them to *output_path*.

    Args:
        dataset_dir: Directory containing ``train.csv``, ``valid.csv`` and
            ``test.csv`` (str or Path).
        output_path: Destination CSV file (str or Path). Missing parent
            directories are created.
    """
    dataset_dir = Path(dataset_dir)
    output_path = Path(output_path)
    frames = [
        pd.read_csv(dataset_dir / f'{split}.csv', encoding=CSV_ENCODING)
        for split in SPLITS
    ]
    matches = collect_matches(frames)
    # to_csv does not create missing directories on its own.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    matches.to_csv(output_path, sep=',', index=False, header=True)


if __name__ == '__main__':
    main()