You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

87 lines
2.9 KiB

# -*- coding: utf-8 -*-
"""
情感分析推理演示脚本
展示如何使用训练好的模型进行预测
"""
import sys
import os
# 导入主模块中的预测器
sys.path.insert(0, os.path.dirname(__file__))
from 微博情感分析GPT5 import SentimentPredictor, config, logger
def _char_width(ch):
    """Return how many terminal columns a single character occupies."""
    import unicodedata

    if unicodedata.combining(ch):
        # Combining marks are drawn on top of the previous character.
        return 0
    # East Asian Wide / Fullwidth characters (CJK ideographs, most emoji)
    # take two columns in a monospace terminal.
    return 2 if unicodedata.east_asian_width(ch) in ("W", "F") else 1


def _display_width(text):
    """Return the number of terminal columns *text* occupies."""
    return sum(_char_width(ch) for ch in text)


def _pad(text, columns):
    """Left-justify *text* to *columns* terminal columns (never truncates)."""
    return text + " " * max(0, columns - _display_width(text))


def _clip(text, columns=45):
    """Shorten *text* to at most *columns* columns, appending "..." if cut."""
    if _display_width(text) <= columns:
        return text
    kept = []
    used = 0
    for ch in text:
        width = _char_width(ch)
        if used + width > columns:
            break
        kept.append(ch)
        used += width
    return "".join(kept) + "..."


def _print_rows(predictor, texts, positive, negative):
    """Print one aligned "text / sentiment / probability" row per text."""
    for text in texts:
        label, prob = predictor.predict_text(text)
        sentiment = positive if label == 1 else negative
        print(f"{_pad(_clip(text), 50)} {_pad(sentiment, 10)} {prob:.4f}")


def demo_predict():
    """Run a few sample sentences through the trained sentiment models.

    Loads the LSTM predictor from ``config.model_dir`` and prints a table of
    predictions for the built-in sample texts. If an SVM model file is also
    present, the first three samples are additionally scored with it.

    Column padding is computed from terminal display width rather than from
    the number of code points, so rows containing CJK text or emoji stay
    aligned. Both tables clip over-long text the same way (with a trailing
    "...").

    Returns:
        None. Results are printed to stdout; progress and errors go to the
        project logger. The function returns early (after logging an error)
        when the LSTM model file or its vocabulary file does not exist.
    """
    rule = "=" * 80

    # Sample texts to classify.
    test_texts = [
        "这家餐厅太好吃了,服务也很棒!强烈推荐!",
        "质量太差了,完全不值这个价格,很失望",
        "还行吧,价格合适,但是没什么特别的",
        "今天心情很好,天气不错",
        "这个产品用起来很方便,值得购买",
    ]

    # LSTM model artifacts.
    lstm_model_path = os.path.join(config.model_dir, "best_lstm.keras")
    lstm_vocab_path = os.path.join(config.model_dir, "lstm_vocab.pkl")

    # Bail out early when the trained artifacts are missing.
    if not os.path.exists(lstm_model_path):
        logger.error(f"LSTM 模型文件不存在: {lstm_model_path}")
        logger.info("请先运行主程序训练模型")
        return
    if not os.path.exists(lstm_vocab_path):
        logger.error(f"词表文件不存在: {lstm_vocab_path}")
        logger.info("请先运行主程序训练模型")
        return

    # Load the predictor.
    logger.info("加载 LSTM 预测器...")
    predictor = SentimentPredictor(
        model_path=lstm_model_path,
        vocab_path=lstm_vocab_path,
        model_type="lstm",
    )

    # Batch prediction with the LSTM model.
    logger.info("\n开始预测...\n")
    print(rule)
    print(f"{_pad('文本', 50)} {_pad('预测', 10)} {_pad('概率', 10)}")
    print(rule)
    _print_rows(predictor, test_texts, "正面 😊", "负面 😞")
    print(rule)

    # Optional SVM demo, only when a trained SVM model exists.
    svm_model_path = os.path.join(config.model_dir, "svm_model.pkl")
    if os.path.exists(svm_model_path):
        logger.info("\n使用 SVM 模型预测...")
        svm_predictor = SentimentPredictor(
            model_path=svm_model_path,
            # Per the original author's note, the SVM vocabulary lives in
            # the same file as the model.
            vocab_path=svm_model_path,
            model_type="svm",
        )
        print("\nSVM 预测结果:")
        print(rule)
        # Only the first three samples are shown for the SVM model.
        _print_rows(svm_predictor, test_texts[:3], "正面", "负面")
        print(rule)
if __name__ == "__main__":
    # Script entry point: run the demo, report any failure through the
    # project logger, and signal it to the shell with exit status 1.
    try:
        demo_predict()
    except Exception as exc:
        logger.exception(f"预测失败: {exc}")
        raise SystemExit(1)