You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
87 lines
2.9 KiB
87 lines
2.9 KiB
# -*- coding: utf-8 -*-
|
|
"""
|
|
情感分析推理演示脚本
|
|
展示如何使用训练好的模型进行预测
|
|
"""
|
|
|
|
import sys
|
|
import os
|
|
|
|
# 导入主模块中的预测器
|
|
sys.path.insert(0, os.path.dirname(__file__))
|
|
from 微博情感分析GPT5 import SentimentPredictor, config, logger
|
|
|
|
def demo_predict():
|
|
"""演示预测功能"""
|
|
|
|
# 测试文本
|
|
test_texts = [
|
|
"这家餐厅太好吃了,服务也很棒!强烈推荐!",
|
|
"质量太差了,完全不值这个价格,很失望",
|
|
"还行吧,价格合适,但是没什么特别的",
|
|
"今天心情很好,天气不错",
|
|
"这个产品用起来很方便,值得购买"
|
|
]
|
|
|
|
# LSTM 模型路径
|
|
lstm_model_path = os.path.join(config.model_dir, "best_lstm.keras")
|
|
lstm_vocab_path = os.path.join(config.model_dir, "lstm_vocab.pkl")
|
|
|
|
# 检查模型是否存在
|
|
if not os.path.exists(lstm_model_path):
|
|
logger.error(f"LSTM 模型文件不存在: {lstm_model_path}")
|
|
logger.info("请先运行主程序训练模型")
|
|
return
|
|
|
|
if not os.path.exists(lstm_vocab_path):
|
|
logger.error(f"词表文件不存在: {lstm_vocab_path}")
|
|
logger.info("请先运行主程序训练模型")
|
|
return
|
|
|
|
# 加载预测器
|
|
logger.info("加载 LSTM 预测器...")
|
|
predictor = SentimentPredictor(
|
|
model_path=lstm_model_path,
|
|
vocab_path=lstm_vocab_path,
|
|
model_type="lstm"
|
|
)
|
|
|
|
# 批量预测
|
|
logger.info("\n开始预测...\n")
|
|
print("=" * 80)
|
|
print(f"{'文本':<50} {'预测':<10} {'概率':<10}")
|
|
print("=" * 80)
|
|
|
|
for text in test_texts:
|
|
label, prob = predictor.predict_text(text)
|
|
sentiment = "正面 😊" if label == 1 else "负面 😞"
|
|
text_display = text[:45] + "..." if len(text) > 45 else text
|
|
print(f"{text_display:<50} {sentiment:<10} {prob:.4f}")
|
|
|
|
print("=" * 80)
|
|
|
|
# SVM 模型演示(如果存在)
|
|
svm_model_path = os.path.join(config.model_dir, "svm_model.pkl")
|
|
if os.path.exists(svm_model_path):
|
|
logger.info("\n使用 SVM 模型预测...")
|
|
svm_predictor = SentimentPredictor(
|
|
model_path=svm_model_path,
|
|
vocab_path=svm_model_path, # SVM 的 vocab 在同一文件
|
|
model_type="svm"
|
|
)
|
|
|
|
print("\nSVM 预测结果:")
|
|
print("=" * 80)
|
|
for text in test_texts[:3]: # 只展示前3个
|
|
label, prob = svm_predictor.predict_text(text)
|
|
sentiment = "正面" if label == 1 else "负面"
|
|
print(f"{text[:45]:<50} {sentiment:<10} {prob:.4f}")
|
|
print("=" * 80)
|
|
|
|
if __name__ == "__main__":
|
|
try:
|
|
demo_predict()
|
|
except Exception as e:
|
|
logger.exception(f"预测失败: {e}")
|
|
sys.exit(1)
|