You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
matching_dependency/bert.py

28 lines
1.4 KiB

1 year ago
import csv
import pandas as pd
import json
import sentence_transformers.util
import torch
from sentence_transformers import SentenceTransformer
from torch import nn
# Labeled splits of the DeepMatcher "Dirty / Walmart-Amazon" dataset, in the
# order in which their rows are concatenated.
SPLIT_PATHS = (
    r'E:\Data\Research\Datasets\DeepMatcher dataset\Dirty\Walmart-Amazon\train.csv',
    r'E:\Data\Research\Datasets\DeepMatcher dataset\Dirty\Walmart-Amazon\valid.csv',
    r'E:\Data\Research\Datasets\DeepMatcher dataset\Dirty\Walmart-Amazon\test.csv',
)

# Destination of the combined table of matching pairs.
MATCHES_PATH = r'E:\Data\Research\Projects\matching_dependency\datasets\Walmart-Amazon_dirty\matches.csv'


def extract_matches(frames):
    """Collect the positively labeled rows of several splits into one table.

    Args:
        frames: Iterable of pandas DataFrames. Each must contain a ``label``
            column and an ``ltable_id`` column.

    Returns:
        A new DataFrame with every row whose ``label`` equals 1, taken from
        the frames in the order given, with the ``label`` column removed and
        the rows sorted by ``ltable_id``. The input frames are not modified
        and the original row indices are kept.

    Raises:
        ValueError: If ``frames`` is empty (raised by ``pd.concat``).
        KeyError: If a frame lacks the ``label`` or ``ltable_id`` column.
    """
    positives = [frame[frame['label'] == 1] for frame in frames]
    matches = pd.concat(positives).drop(columns=['label'])
    # A stable sort keeps rows that share an ltable_id in concatenation order,
    # so the generated file is reproducible; the default quicksort gives no
    # such guarantee for ties.
    return matches.sort_values(by='ltable_id', kind='mergesort')


def main():
    """Read the train/valid/test splits and write their matches to MATCHES_PATH."""
    # The source CSVs are read as ISO-8859-1; the output uses to_csv's default
    # encoding.
    frames = [pd.read_csv(path, encoding='ISO-8859-1') for path in SPLIT_PATHS]
    matches = extract_matches(frames)
    matches.to_csv(MATCHES_PATH, sep=',', index=False, header=True)


if __name__ == '__main__':
    main()