"""Outfit preview images for the recommendation UI.

Builds simple outfit collages from Fashion-MNIST samples (when TensorFlow is
installed) for the clothing keywords found in a recommendation text, and falls
back to text-only placeholder images otherwise.
"""
import functools
import os
from datetime import datetime
from io import BytesIO

import cv2
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
from PIL import Image

try:
    import tensorflow as tf
    from tensorflow.keras.datasets import fashion_mnist

    HAS_TF = True
except ImportError:
    HAS_TF = False

# Default save location kept from the original code (a Windows desktop path).
DEFAULT_OUTPUT_DIR = "C:/Users/Administrator/Desktop"

# Display names of the Fashion-MNIST classes, indexed by label id.
CLASS_NAMES = ['T恤', '裤子', '套头衫', '连衣裙', '外套', '凉鞋', '衬衫', '运动鞋', '包', '短靴']

# Chinese clothing keyword -> Fashion-MNIST label id. Matching is a substring
# test against the lower-cased text, so an upper-case "T" also matches "t恤".
_KEYWORD_TO_LABEL = {
    "t恤": 0, "衬衫": 6, "外套": 4, "套头衫": 2,
    "裤子": 1, "牛仔裤": 1,
    "连衣裙": 3, "裙子": 3,
    "凉鞋": 5, "运动鞋": 7, "短靴": 9, "靴子": 9,
}

_DRESS_LABEL = 3
_TOP_LABELS = frozenset({0, 2, 4, 6})    # T-shirt, pullover, coat, shirt
_BOTTOM_LABELS = frozenset({1})          # trousers
_SHOE_LABELS = frozenset({5, 7, 9})      # sandal, sneaker, ankle boot
_TILE = 128  # width (and base height), in pixels, of one item tile in a collage


def map_keyword_to_indices(text, class_names=None):
    """Return the sorted Fashion-MNIST label ids whose keyword appears in ``text``.

    Args:
        text: Recommendation text; ``None`` or ``""`` yields no matches.
        class_names: Unused; accepted so existing callers that pass the
            class-name list keep working.

    Returns:
        Sorted list of unique label ids.
    """
    if not text:
        return []
    lowered = text.lower()
    return sorted({label for keyword, label in _KEYWORD_TO_LABEL.items() if keyword in lowered})


@functools.lru_cache(maxsize=1)
def _load_fashion_mnist_train():
    """Load the Fashion-MNIST training split once; return (images, {label: sample ids})."""
    (train_images, train_labels), _ = fashion_mnist.load_data()
    ids_by_label = {
        int(label): np.flatnonzero(train_labels == label)
        for label in np.unique(train_labels)
    }
    return train_images, ids_by_label


def _timestamp():
    """Return a filename-safe local timestamp with microsecond resolution."""
    return datetime.now().strftime("%Y%m%d_%H%M%S_%f")


def _save_png(img, path):
    """Best-effort PNG save: a failure is reported in the UI instead of raised."""
    try:
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        img.save(path, format="PNG")
    except OSError as exc:
        st.warning(f"保存图片失败: {exc}")


def _feasible_layouts(matched):
    """List the outfit layouts the ``matched`` labels can completely fill.

    Each layout is a list of ``(candidate_labels, tile_height)`` slots that are
    stacked top to bottom.
    """
    tops = [label for label in matched if label in _TOP_LABELS]
    bottoms = [label for label in matched if label in _BOTTOM_LABELS]
    shoes = [label for label in matched if label in _SHOE_LABELS]
    layouts = []
    if _DRESS_LABEL in matched and shoes:
        # A dress replaces top + bottom, so it gets a double-height tile.
        layouts.append([([_DRESS_LABEL], 2 * _TILE), (shoes, _TILE)])
    if tops and bottoms and shoes:
        layouts.append([(tops, _TILE), (bottoms, _TILE), (shoes, _TILE)])
    return layouts


def _sample_tile(images, ids_by_label, label, height):
    """Return a random image of ``label`` resized to ``_TILE`` wide by ``height`` tall."""
    sample = images[np.random.choice(ids_by_label[label])]
    # cv2.resize takes dsize as (width, height); nearest keeps the pixel-art look.
    return cv2.resize(sample, (_TILE, height), interpolation=cv2.INTER_NEAREST)


def _compose_outfit(images, ids_by_label, layout):
    """Stack one randomly chosen item per slot of ``layout`` into an RGB PIL image."""
    tiles = [
        _sample_tile(images, ids_by_label, int(np.random.choice(candidates)), height)
        for candidates, height in layout
    ]
    gray = np.vstack(tiles)
    # The dataset images are single-channel; replicate into three channels for RGB.
    return Image.fromarray(np.stack([gray] * 3, axis=-1))


def generate_sample_outfits(recommendation_text, save_to_local=False,
                            output_dir=DEFAULT_OUTPUT_DIR, num_outfits=3):
    """Build outfit collage images for the clothing mentioned in a recommendation.

    Each outfit is either dress + shoes or top + bottom + shoes. Only the
    layouts the matched keywords can completely fill are used; when both are
    possible one is picked at random per outfit.

    Args:
        recommendation_text: Free text scanned for clothing keywords.
        save_to_local: If true, also write each image as a PNG into ``output_dir``.
        output_dir: Directory for saved PNGs; created on demand.
        num_outfits: Number of outfits to generate (default 3).

    Returns:
        A list of PIL images. Text placeholders from ``generate_fallback_outfits``
        are returned instead when TensorFlow is missing, no keyword matches,
        no complete outfit can be assembled, or the dataset fails to load.
    """
    if not HAS_TF:
        return generate_fallback_outfits(save_to_local, output_dir, num_outfits)

    try:
        matched = map_keyword_to_indices(recommendation_text, CLASS_NAMES)
        if not matched:
            st.warning("无法从推荐文本中识别有效服装类别")
            return generate_fallback_outfits(save_to_local, output_dir, num_outfits)

        layouts = _feasible_layouts(matched)
        if not layouts:
            # Keywords matched, but no complete outfit is possible (e.g. no shoes).
            return generate_fallback_outfits(save_to_local, output_dir, num_outfits)

        # Load only once we know images will actually be composed.
        images, ids_by_label = _load_fashion_mnist_train()

        timestamp = _timestamp()  # one stamp per batch keeps its files grouped
        outfits = []
        for i in range(num_outfits):
            try:
                layout = layouts[np.random.randint(len(layouts))]
                img = _compose_outfit(images, ids_by_label, layout)
                if save_to_local:
                    _save_png(img, os.path.join(output_dir, f"outfit_{timestamp}_{i}.png"))
                outfits.append(img)
            except Exception as e:
                # Keep the batch going: one bad outfit becomes a text placeholder.
                st.warning(f"生成第{i + 1}套穿搭图片时出错: {str(e)}")
                outfits.append(generate_single_fallback_outfit(i + 1, save_to_local, output_dir))
        return outfits if outfits else generate_fallback_outfits(save_to_local, output_dir, num_outfits)
    except Exception as e:
        # UI boundary: report and degrade to text instead of breaking the page.
        st.warning(f"生成服装图片出错: {str(e)},使用文字替代")
        return generate_fallback_outfits(save_to_local, output_dir, num_outfits)


def generate_single_fallback_outfit(index, save_to_local=False, output_dir=None):
    """Render a text-only placeholder image for outfit number ``index``.

    Args:
        index: 1-based outfit number shown in the placeholder text.
        save_to_local: If true and ``output_dir`` is set, also save the image as PNG.
        output_dir: Directory for the saved PNG; created on demand.

    Returns:
        The placeholder as a fully decoded PIL image.
    """
    fig, ax = plt.subplots(figsize=(3, 4))
    try:
        ax.set_facecolor('#f0f2f6')
        ax.text(0.5, 0.5, f"穿搭方案 {index}\n(图片生成失败)", ha='center', va='center', fontsize=10)
        ax.axis('off')
        buf = BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0.1, dpi=100)
    finally:
        # Always release the figure; pyplot otherwise keeps every figure alive.
        plt.close(fig)
    buf.seek(0)
    img = Image.open(buf)
    img.load()  # Image.open is lazy; decode now so the image is self-contained.

    if save_to_local and output_dir:
        _save_png(img, os.path.join(output_dir, f"outfit_{_timestamp()}_{index}_fallback.png"))
    return img


def generate_fallback_outfits(save_to_local=False, output_dir=DEFAULT_OUTPUT_DIR, num_outfits=3):
    """Return ``num_outfits`` text-only placeholder images, numbered from 1."""
    return [
        generate_single_fallback_outfit(number, save_to_local, output_dir)
        for number in range(1, num_outfits + 1)
    ]