You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

291 lines
268 KiB

5 months ago
{
"cells": [
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"item_bert_emb_dict = pickle.load(open(\"../../results/shixun/sample/shixuns_bert_emb_dict.pkl\", 'rb'))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"dict_keys([43, 49, 50, 51, 53, 54, 56, 57, 58, 60, 61, 62, 63, 64, 65, 67, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 80, 81, 82, 83, 84, 85, 88, 89, 90, 91, 92, 93, 95, 96, 97, 99, 100, 101, 102, 103, 104, 108, 110, 118, 119, 120, 121, 122, 123, 132, 133, 136, 137, 140, 141, 142, 143, 144, 145, 148, 149, 152, 153, 154, 155, 156, 157, 158, 161, 164, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 180, 181, 183, 186, 187, 189, 190, 192, 193, 194, 195, 196, 203, 206, 207, 208, 209, 211, 213, 218, 220, 221, 222, 223, 224, 225, 229, 230, 231, 232, 233, 234, 235, 236, 249, 251, 252, 256, 273, 275, 277, 278, 279, 299, 300, 305, 307, 308, 309, 312, 313, 328, 329, 330, 333, 340, 341, 342, 344, 346, 347, 350, 352, 353, 356, 361, 366, 367, 368, 369, 375, 377, 378, 379, 382, 383, 386, 387, 398, 402, 403, 405, 406, 408, 418, 421, 424, 425, 426, 427, 428, 436, 440, 448, 449, 451, 452, 453, 454, 456, 457, 460, 462, 463, 464, 465, 466, 469, 473, 474, 476, 477, 478, 481, 483, 484, 485, 487, 490, 494, 495, 497, 498, 499, 503, 504, 505, 506, 511, 515, 522, 523, 524, 526, 527, 528, 534, 535, 536, 538, 539, 542, 543, 544, 548, 555, 559, 563, 570, 571, 572, 573, 574, 576, 577, 592, 593, 594, 597, 602, 603, 606, 607, 608, 609, 611, 613, 618, 619, 620, 623, 625, 626, 627, 628, 629, 630, 633, 638, 639, 641, 642, 644, 647, 649, 651, 652, 659, 660, 669, 673, 675, 677, 678, 692, 693, 695, 700, 711, 717, 719, 720, 721, 722, 723, 724, 725, 738, 739, 741, 743, 744, 746, 747, 748, 749, 750, 752, 753, 756, 757, 759, 760, 761, 762, 763, 766, 767, 768, 769, 771, 772, 774, 776, 778, 779, 782, 783, 784, 785, 786, 788, 789, 793, 794, 795, 796, 797, 798, 799, 800, 802, 803, 805, 806, 807, 809, 815, 816, 817, 818, 819, 820, 822, 823, 826, 827, 828, 829, 834, 836, 837, 838, 839, 840, 841, 842, 848, 849, 850, 867, 868, 870, 871, 872, 875, 876, 877, 879, 880, 881, 882, 883, 885, 886, 887, 888, 889, 890, 891, 892, 893, 895, 897, 899, 900, 907, 908, 909, 910, 911, 912, 918, 919, 920, 922, 924, 926, 928, 929, 930, 931, 933, 934, 937, 939, 946, 947, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 965, 966, 967, 968, 969, 970, 971, 972, 973, 974, 975, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, 993, 994, 995, 1000, 1001, 1005, 1006, 1007, 1008, 1009, 1011, 1012, 1013, 1014, 1015, 1016, 1020, 1022, 1023, 1025, 1026, 1027, 1029, 1030, 1031, 1032, 1033, 1034, 1035, 1036, 1037, 1039, 1041, 1042, 1043, 1044, 1045, 1046, 1047, 1049, 1050, 1051, 1052, 1053, 1054, 1055, 1057, 1060, 1061, 1063, 1064, 1065, 1067, 1074, 1075, 1076, 1078, 1079, 1080, 1081, 1082, 1084, 1091, 1092, 1093, 1094, 1095, 1096, 1101, 1102, 1103, 1104, 1105, 1106, 1107, 1111, 1112, 1113, 1116, 1117, 1119, 1122, 1124, 1125, 1127, 1129, 1130, 1133, 1134, 1137, 1139, 1140, 1142, 1144, 1145, 1146, 1147, 1148, 1150, 1151, 1155, 1156, 1157, 1159, 1160, 1161, 1162, 1163, 1164, 1167, 1168, 1169, 1170, 1171, 1172, 1174, 1175, 1176, 1177, 1178, 1180, 1181, 1182, 1183, 1185, 1186, 1187, 1188, 1190, 1191, 1193, 1195, 1197, 1198, 1201, 1202, 1203, 1207, 1208, 1210, 1213, 1215, 1216, 1218, 1219, 1220, 1221, 1222, 1223, 1243, 1244, 1245, 1250, 1253, 1254, 1255, 1256, 1258, 1259, 1260, 1262, 1263, 1264, 1265, 1267, 1268, 1270, 1271, 1275, 1276, 1277, 1278, 1279, 1285, 1287, 1289, 1290, 1291, 1294, 1300, 1301, 1304, 1305, 1307, 1308, 1310, 1314, 1315, 1318, 1320, 1323, 1326, 1327, 1331, 1333, 1342, 1350, 1352, 1370, 1373, 1377, 1381, 1383, 1386, 1389, 1394, 1395, 1397, 1398, 1403, 1404, 1405, 1406, 1407, 1408, 1409, 1419, 1423, 1426, 1448, 1449, 1452, 1453, 1454, 1455, 1456, 1461, 1462, 1466, 1467, 1469, 1470, 1472, 1473, 1474, 1475, 1476, 1479, 1480, 1481, 1483, 1487, 1488, 1492, 1493, 1494, 1495, 1497, 1499, 1500, 1501, 1502, 1503, 1507, 1508, 1509, 1515, 1516, 1517, 1518, 1519, 1522, 1526, 1527, 1528, 1530, 1537, 1540, 1542, 1547, 1548, 1549, 1550, 1552, 1564, 1566, 1568, 1569, 1570, 1571, 1576, 1586, 1591, 1592, 1594, 1600, 1603, 1604, 1605, 1606, 1607, 1608, 1609, 1621, 1622, 1624, 1625, 1628, 1630, 1631, 1632, 1634, 1635, 1642, 1643, 1647, 1648, 1649, 1650, 1651, 1
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"item_bert_emb_dict.keys()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"shixuns_emb=item_bert_emb_dict"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ 0.04126396, 0.04197039, 0.03578058, -0.00456123, 0.02694773,\n",
" 0.02667332, -0.02273138, 0.01312323, 0.0417384 , -0.02911825,\n",
" 0.04198574, 0.04161852, 0.03541076, -0.04089891, 0.04198033,\n",
" -0.04170334, -0.0166467 , 0.04194591, 0.04156616, 0.02320999,\n",
" 0.0416895 , -0.041989 , -0.04165329, -0.03914439, 0.0067468 ,\n",
" 0.02596774, 0.03412744, -0.0081602 , -0.04198273, 0.03937697,\n",
" 0.03644141, 0.04186572, 0.04105975, -0.04198475, -0.0370855 ,\n",
" 0.0100278 , -0.04090442, 0.04121647, -0.04045089, -0.02107013,\n",
" -0.04068257, 0.01064464, 0.00713908, -0.03255539, -0.04196075,\n",
" 0.00983701, -0.04198876, -0.04176414, 0.03463598, 0.04196657,\n",
" -0.01624174, -0.04197881, 0.03254087, -0.0360364 , -0.02825117,\n",
" 0.03540781, -0.04135983, 0.03376076, 0.04198885, 0.03603492,\n",
" 0.04195728, -0.03960222, -0.02160683, -0.04101721, 0.04198818,\n",
" -0.04181207, -0.00218115, -0.00350076, 0.04198389, 0.04197426,\n",
" -0.04039839, 0.03370974, 0.04198639, -0.01080438, 0.03208927,\n",
" 0.04109811, -0.03798063, 0.01973926, -0.04197707, -0.04042572,\n",
" 0.04198582, 0.04137517, -0.03217209, 0.03209328, -0.03233838,\n",
" -0.04195601, -0.04198622, 0.04087406, -0.02032096, 0.04166042,\n",
" 0.04172924, -0.04144742, -0.04198722, 0.04054035, -0.04182615,\n",
" -0.04091164, 0.03532492, 0.03901083, -0.0259569 , -0.00172883,\n",
" -0.03467132, 0.03808066, -0.04027368, -0.03852938, 0.04135282,\n",
" 0.0418681 , 0.03186438, -0.04031213, 0.04192768, 0.03089509,\n",
" -0.04198885, -0.03153783, -0.04150786, -0.02570644, -0.03937201,\n",
" 0.04197528, 0.04152225, 0.01586326, 0.04171605, -0.03840322,\n",
" 0.00729388, -0.04194935, -0.02603213, -0.03478069, 0.04160694,\n",
" 0.04197155, 0.03831626, -0.04197625, 0.0416682 , 0.04198779,\n",
" 0.03933371, 0.03543851, -0.03822474, 0.0401117 , 0.03281059,\n",
" -0.03754264, 0.0317249 , 0.00940946, 0.04198861, 0.03731949,\n",
" 0.02482679, -0.03630931, 0.04194146, -0.04146268, 0.04193872,\n",
" -0.04196607, 0.04053254, -0.04198255, -0.04185054, 0.04196976,\n",
" 0.04141143, 0.0419884 , -0.03408257, 0.04198183, -0.04186241,\n",
" -0.04122971, 0.04195711, -0.00407482, 0.03515287, -0.04197061,\n",
" 0.04009439, -0.02275989, -0.0385541 , 0.03808279, -0.0419889 ,\n",
" 0.04195229, -0.02864223, 0.04198409, 0.04184332, -0.03735469,\n",
" -0.04166716, -0.04105362, 0.04166531, -0.04015122, -0.00027225,\n",
" 0.03903495, 0.03633447, 0.03427457, -0.02986496, -0.04094725,\n",
" 0.03943654, -0.01898882, -0.04193359, 0.04160382, -0.01473461,\n",
" -0.0046962 , -0.03931821, 0.0345024 , 0.04177358, 0.04030373,\n",
" -0.01875442, 0.04198079, -0.01925659, 0.04195868, 0.03621402,\n",
" -0.01915105, -0.01232733, -0.04181978, -0.04191966, 0.0372053 ,\n",
" 0.0419848 , -0.02579133, -0.03965329, 0.04021261, -0.04197234,\n",
" 0.04034818, -0.00077876, 0.03246009, -0.04195715, -0.04190793,\n",
" 0.04198784, -0.04132089, -0.03951168, 0.03418558, -0.0394355 ,\n",
" -0.0140977 , -0.04170181, 0.00830208, 0.04057429, -0.03637843,\n",
" -0.03626586, 0.00491304, -0.04176958, 0.04195475, -0.04189665,\n",
" 0.03053708, 0.03490473, 0.04198675, 0.03573518, 0.00484118,\n",
" -0.03846973, 0.04138119, 0.0148096 , -0.04198892, 0.0398224 ,\n",
" -0.04185105, -0.03467244, 0.04195027, -0.04187242, 0.03901177,\n",
" 0.04198557, 0.0127536 , 0.04198862, 0.01660896, -0.04144615,\n",
" -0.04162048, 0.04198782, 0.04189807, 0.04194672, -0.04117197,\n",
" -0.04154793, -0.00861619, -0.01612137, -0.04198625, -0.04175937,\n",
" -0.03581057, 0.04069342, 0.04197742, -0.00927695, -0.04193225,\n",
" 0.01149309, -0.04092141, 0.04198447, -0.02992041, 0.04198872,\n",
" 0.04116932, -0.04038519, -0.04066702, 0.04136181, -0.01259178,\n",
" -0.04185093, -0.03593153, -0.04197553, -0.02911972, -0.04197792,\n",
" 0.04160346, -0.04164274, -0.04198721, 0.04028285, 0.04197734,\n",
" 0.04134644, -0.04198376, 0.03973573, 0.04166789, -0.03968116,\n",
" -0.04158824, 0.0379812 , -0.04198742, 0.04198844, -0.03801185,\n",
" 0.00107238, -0.03872211, -0.0415947 , -0.01429117, 0.04192708,\n",
" 0.04175217, -0.0407503 , -0.03052849, -0.03178486, -0.04178818,\n",
" -0.00835781, 0.01538996, 0.02737717, 0.02690919, -0.01840801,\n",
" -0.04166178, 0.0363229 , -0.0419296 , -0.04194051, -0.0144698 ,\n",
" 0.04198433, 0.02757781, 0.04198834, -0.0324889 , 0.04198751,\n",
" 0.03295009, -0.03593921, 0.0389869 , -0.03833 , -0.03391436,\n",
" -0.01993591, -0.03885526, 0.03281586, -0.02741082, -0.03237478,\n",
" -0.0415874 , 0.04196396, 0.04198085, 0.04180834, 0.03249368,\n",
" -0.00363372, -0.02481618, 0.01267587, -0.04102535, 0.04187071,\n",
" -0.04181998, -0.04039235, 0.04196578, 0.04196031, 0.04189978,\n",
" 0.03125617, -0.03059685, 0.03569997, -0.03805297, 0.04073007,\n",
" -0.04092512, 0.04164406, -0.0410394 , 0.0304983 , 0.00024938,\n",
" -0.04141904, 0.04198414, 0.02100157, -0.02216624, 0.04191901,\n",
" -0.03446326, 0.04078055, 0.03740211, 0.03694592, 0.04049973,\n",
" -0.01364386, 0.04195723, -0.03984711, -0.04190143, -0.01631859,\n",
" -0.04148455, -0.04166832, -0.04198775, 0.00314687, -0.04039371,\n",
" -0.03810394, -0.01921219, 0.03632106, 0.00925723, -0.03913045,\n",
" 0.02540023, -0.00609886, 0.03419326, -0.01275209, 0.01182831,\n",
" 0.03622852, -0.04188388, -0.02013846, -0.04198548, -0.04091489,\n",
" 0.03813241, 0.04197209, -0.04197228, 0.03769003, -0.04198314,\n",
" -0.03352323, 0.03056168, 0.00524293, -0.02513863, 0.04198688,\n",
" -0.04198511, 0.03171355, 0.04185796, 0.04198896, 0.04123863,\n",
" 0.04191539, -0.03331638, -0.04137757, -0.04195552, -0.04198602,\n",
" -0.04198543, -0.04173329, 0.02361188, 0.04030795, -0.04197871,\n",
" -0.02735724, 0.03485049, 0.04198449, 0.03934003, -0.04191162,\n",
" -0.01556293, -0.04020228, -0.04121891, 0.03989344, -0.03751295,\n",
" -0.04160217, 0.04157483, -0.03747531, 0.04197649, -0.03330726,\n",
" 0.041797 , 0.01553499, 0.03843407, -0.01980506, -0.0419789 ,\n",
" 0.03165124, 0.04198403, 0.03164057, -0.04198309, -0.04103299,\n",
" -0.04036099, -0.04191351, -0.003134 , 0.02560875, 0.04164857,\n",
" -0.04198804, -0.01305335, -0.04181057, 0.01808916, 0.02857054,\n",
" 0.0413773 , 0.04193968, 0.03466176, 0.04067586, 0.04194682,\n",
" 0.00603348, 0.04198304, -0.00531589, -0.04173062, 0.04077149,\n",
" -0.00526975, 0.02699065, -0.04150279, 0.03724448, 0.03875501,\n",
" 0.04198527, 0.04139676, -0.00845909, -0.03940297, -0.01727216,\n",
" 0.04056164, 0.04198492, -0.03924002, -0.03460339, -0.0416716 ,\n",
" -0.04190092, -0.04094217, -0.02820578, -0.03560956, -0.03661607,\n",
" -0.04190323, 0.03896662, 0.03130606, 0.04198624, 0.04198651,\n",
" 0.04182592, -0.03883908, -0.04025847, 0.04177434, 0.03910935,\n",
" 0.04137958, -0.01848272, -0.04198824, -0.04076474, -0.04186779,\n",
" 0.04197471, 0.02308419, -0.02248645, -0.03875297, -0.00690083,\n",
" 0.03050664, -0.04163408, -0.04189519, -0.0419333 , -0.03057267,\n",
" 0.0419875 , -0.0419456 , 0.04154235, -0.0406983 , 0.03434575,\n",
" 0.01125377, -0.02710795, 0.0417644 , -0.00712127, -0.0185032 ,\n",
" -0.03710214, 0.04065449, 0.04168561, 0.04179131, -0.03903219,\n",
" -0.00471208, 0.0388105 , 0.00582988, 0.04196824, 0.02086509,\n",
" 0.03686011, 0.03886231, 0.04197226, 0.00494676, 0.04197789,\n",
" 0.04121486, 0.04198334, 0.04197854, -0.04191182, -0.02971651,\n",
" 0.03768668, -0.00246194, -0.0329401 , 0.02837821, 0.04198884,\n",
" 0.03481174, -0.03618993, -0.04194475, 0.0417482 , 0.04196004,\n",
" 0.04198839, -0.0213686 , 0.04190777, -0.03204934, 0.03499636,\n",
" 0.01937684, 0.01904153, -0.02334654, 0.01754579, 0.0411477 ,\n",
" 0.03963953, -0.04189899, -0.04198864, -0.04198857, 0.04198739,\n",
" 0.04197171, -0.01784322, -0.04198823, 0.04174543, -0.0050752 ,\n",
" 0.04196675, 0.04162002, -0.00382628, -0.03855931, 0.04065217,\n",
" -0.04181691, 0.00271896, 0.0393075 , -0.01726509, 0.04005548,\n",
" 0.03836541, -0.04179063, 0.01329095, 0.04198483, 0.02608233,\n",
" 0.04195726, 0.02435604, -0.04173591, 0.04183252, -0.04187825,\n",
" -0.04193632, -0.04011936, 0.04197941, 0.04191554, -0.02589057,\n",
" 0.03422449, 0.04180543, -0.0419754 , 0.04198517, -0.04198319,\n",
" 0.03433901, -0.04172321, 0.04195996, -0.01611307, -0.04151817,\n",
" 0.00361194, 0.03229348, 0.03992271, -0.04146889, 0.04196749,\n",
" -0.01609168, -0.0397304 , -0.0369703 , -0.04134296, -0.04166955,\n",
" -0.04144209, 0.03486018, -0.04198373, 0.03788313, 0.02914115,\n",
" -0.01495758, -0.03724721, -0.04198475, 0.04194922, -0.00868783,\n",
" -0.03975865, 0.04197279, -0.04183043, -0.0419866 , 0.04186398,\n",
" -0.03901847, -0.00921086, 0.03969617, 0.01368726, -0.02623666,\n",
" -0.0419848 , 0.02551378, 0.04196016, -0.04029222, -0.03572451,\n",
" -0.03911192, -0.03902108, 0.04146091, 0.04026022, -0.00460039,\n",
" 0.02867268, 0.03908598, 0.04178156, 0.00463572, 0.0019693 ,\n",
" 0.03814286, -0.04198531, -0.04182648, -0.00759568, -0.04068376,\n",
" -0.04198061, -0.04197848, 0.04198753, 0.04181282, 0.04198704,\n",
" 0.03055211, -0.02259165, 0.00596491, 0.03945321, -0.0419193 ,\n",
" -0.03403739, 0.0386718 , 0.02885107, -0.03786369, -0.04142422,\n",
" -0.0383199 , -0.04198244, -0.01919667, -0.00147058, -0.041768 ,\n",
" -0.00437526, 0.04198117, 0.04197293, -0.04195101, -0.04140807,\n",
" -0.03768863, -0.04153134, 0.04196649, 0.04143959, 0.03721406,\n",
" -0.0117665 , -0.01828084, 0.03204167, 0.0094318 , 0.01711832,\n",
" -0.04069522, -0.03459236, -0.04198005, 0.02790442, -0.04166448,\n",
" -0.04198093, 0.04142441, 0.0419841 , 0.02891653, -0.04196152,\n",
" 0.0028069 , 0.0419777 , 0.03644475, 0.04198397, -0.00739348,\n",
" 0.04195189, -0.04166092, 0.03832536, -0.04198517, 0.04198223,\n",
" -0.04198345, 0.04197896, 0.04198651, 0.01010275, 0.04168402,\n",
" -0.04182216, 0.04058226, 0.04089009, -0.03151694, 0.03649032,\n",
" -0.03148106, -0.04187498, 0.03986583, 0.04049552, -0.00552932,\n",
" 0.04198367, 0.0415343 , 0.02764859, 0.02312844, 0.01488893,\n",
" 0.04171284, -0.03404972, -0.0401731 , 0.0419626 , 0.04197323,\n",
" 0.04104799, 0.0419841 , 0.03076882, 0.04198139, -0.03973533,\n",
" -0.04182937, 0.04068797, 0.00520835, 0.02324562, -0.04194023,\n",
" 0.04195091, 0.04166578, -0.04197302, -0.02371343, -0.02706317,\n",
" -0.0146492 , 0.04198693, 0.04067883, 0.04181696, 0.04144626,\n",
" -0.01463645, 0.04193835, -0.0393066 , 0.03269231, -0.03796809,\n",
" -0.00075777, 0.04198299, -0.01835105, 0.04174255, -0.04009363,\n",
" 0.04111317, -0.04169864, 0.0417691 , 0.04009948, 0.04139231,\n",
" -0.04127614, 0.04198181, 0.03066452, -0.04176292, -0.04181974,\n",
" -0.04178266, -0.04136662, 0.02656358])"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"shixuns_emb[50].reshape(-1)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"def get_cos_similar_matrix(v1, v2):\n",
" #获取两个向量的余弦相似度\n",
" num = np.dot(v1, v2) # 向量点乘\n",
" denom = np.linalg.norm(v1).reshape(-1) * np.linalg.norm(v2).reshape(-1) # 求模长的乘积\n",
" res = num / denom\n",
" res[np.isneginf(res)] = 0.0 #负无穷大的赋值0\n",
" # return num\n",
" return float(0.5 + 0.5 * res)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9740721281077637"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_cos_similar_matrix(shixuns_emb[49], shixuns_emb[50])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "mooc",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.2"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}