You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

159 lines
6.6 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import functools
import os
from datetime import datetime
from io import BytesIO

import numpy as np
import matplotlib.pyplot as plt
import cv2
from PIL import Image
import streamlit as st

try:
    import tensorflow as tf
    from tensorflow.keras.datasets import fashion_mnist
    HAS_TF = True
except ImportError:
    HAS_TF = False
# Keyword -> Fashion-MNIST class index. Keys are lowercase because the input
# text is lowercased before matching (so "T恤" in the text matches "t恤").
_KEYWORD_TO_CLASS_INDEX = {
    # Chinese keyword mapping
    "t恤": 0, "衬衫": 6, "外套": 4, "套头衫": 2,
    "裤子": 1, "牛仔裤": 1,
    "连衣裙": 3, "裙子": 3,
    "凉鞋": 5, "运动鞋": 7, "短靴": 9, "靴子": 9,
}


def map_keyword_to_indices(text, class_names):
    """Extract the Fashion-MNIST class indices mentioned in a recommendation text.

    Matching is a case-insensitive substring search against
    ``_KEYWORD_TO_CLASS_INDEX``. Several keywords can map to the same class
    (e.g. "裤子" and "牛仔裤" both map to 1), so duplicates are collapsed.

    Args:
        text: Free-form recommendation text. ``None`` or ``""`` yields ``[]``.
        class_names: Currently unused; kept so existing callers keep working.

    Returns:
        A sorted list of unique class indices found in ``text``. The list is
        sorted so the result is deterministic rather than in set order.
    """
    if not text:
        return []
    text_lower = text.lower()
    return sorted({
        idx
        for keyword, idx in _KEYWORD_TO_CLASS_INDEX.items()
        if keyword in text_lower
    })
# outfits.py 修改后的代码
# Fashion-MNIST class groups used to assemble outfits (names per class_names below).
_TOP_CLASSES = (0, 2, 4, 6)  # T恤, 套头衫, 外套, 衬衫
_BOTTOM_CLASSES = (1,)  # 裤子
_SHOE_CLASSES = (5, 7, 9)  # 凉鞋, 运动鞋, 短靴
_DRESS_CLASS = 3  # 连衣裙


@functools.lru_cache(maxsize=1)
def _load_fashion_mnist_train():
    """Load the Fashion-MNIST training split once and reuse it on later calls.

    Returns:
        ``(train_images, train_labels)`` from ``fashion_mnist.load_data()``.
    """
    (train_images, train_labels), _ = fashion_mnist.load_data()
    return train_images, train_labels


def _random_item_image(train_images, train_labels, class_idx, size):
    """Pick a random training image labelled ``class_idx`` and resize it.

    ``size`` is in OpenCV's ``(width, height)`` order. Nearest-neighbour
    interpolation keeps hard pixel edges when upscaling.
    """
    candidates = np.where(train_labels == class_idx)[0]
    item = train_images[np.random.choice(candidates)]
    return cv2.resize(item, size, interpolation=cv2.INTER_NEAREST)


def _save_outfit_image(img, output_dir, outfit_idx):
    """Save ``img`` as ``outfit_<timestamp>_<outfit_idx>.png`` inside ``output_dir``."""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    img.save(os.path.join(output_dir, f"outfit_{timestamp}_{outfit_idx}.png"), format="PNG")


def generate_sample_outfits(recommendation_text, save_to_local=False, output_dir="C:/Users/Administrator/Desktop"):
    """Build three outfit images from the clothing named in a recommendation text.

    Clothing keywords in ``recommendation_text`` are mapped to Fashion-MNIST
    classes. Each outfit is either dress + shoes or top + bottom + shoes,
    stacked vertically into one RGB PIL image. Text placeholder cards from
    ``generate_fallback_outfits`` are returned instead when TensorFlow is not
    available, when no usable combination can be formed, or when generation
    fails outright. A single outfit that fails is replaced by one placeholder.

    Args:
        recommendation_text: Free-form recommendation text (Chinese keywords).
        save_to_local: If True, also write every generated image as a PNG.
        output_dir: Directory for saved images; created if it does not exist.

    Returns:
        A list of PIL images.
    """
    if save_to_local:
        os.makedirs(output_dir, exist_ok=True)
    if not HAS_TF:
        return generate_fallback_outfits(save_to_local, output_dir)

    try:
        train_images, train_labels = _load_fashion_mnist_train()
        class_names = ['T恤', '裤子', '套头衫', '连衣裙', '外套',
                       '凉鞋', '衬衫', '运动鞋', '', '短靴']
        matched = map_keyword_to_indices(recommendation_text, class_names)
        if not matched:
            st.warning("无法从推荐文本中识别有效服装类别")
            return generate_fallback_outfits(save_to_local, output_dir)

        tops = [c for c in matched if c in _TOP_CLASSES]
        bottoms = [c for c in matched if c in _BOTTOM_CLASSES]
        shoes = [c for c in matched if c in _SHOE_CLASSES]
        # Decide up front which outfit types are feasible. Previously each
        # iteration flipped a coin and silently skipped when the chosen type
        # was impossible, yielding fewer than 3 outfits (or none at all).
        can_dress = bool(shoes) and _DRESS_CLASS in matched
        can_separates = bool(shoes and tops and bottoms)
        if not (can_dress or can_separates):
            return generate_fallback_outfits(save_to_local, output_dir)

        outfits = []
        for i in range(3):
            try:
                # 50/50 when both outfit types are feasible, else the only feasible one.
                use_dress = can_dress and (
                    not can_separates or bool(np.random.choice([True, False]))
                )
                shoes_img = _random_item_image(
                    train_images, train_labels, np.random.choice(shoes), (128, 128))
                if use_dress:
                    # The dress is 256px tall so dress + shoes (384px) matches
                    # the height of a 3 x 128px top/bottom/shoes stack.
                    pieces = [
                        _random_item_image(train_images, train_labels, _DRESS_CLASS, (128, 256)),
                        shoes_img,
                    ]
                else:
                    pieces = [
                        _random_item_image(
                            train_images, train_labels, np.random.choice(tops), (128, 128)),
                        _random_item_image(
                            train_images, train_labels, np.random.choice(bottoms), (128, 128)),
                        shoes_img,
                    ]
                combined = np.vstack(pieces)
                # Grayscale -> RGB by repeating the single channel three times.
                img = Image.fromarray(np.stack([combined] * 3, axis=-1))
                if save_to_local:
                    _save_outfit_image(img, output_dir, i)
                outfits.append(img)
            except Exception as e:
                st.warning(f"生成第{i + 1}套穿搭图片时出错: {str(e)}")
                outfits.append(generate_single_fallback_outfit(i + 1, save_to_local, output_dir))
        return outfits
    except Exception as e:
        # UI boundary: report and degrade to text placeholders instead of crashing.
        st.warning(f"生成服装图片出错: {str(e)},使用文字替代")
        return generate_fallback_outfits(save_to_local, output_dir)
def generate_single_fallback_outfit(index, save_to_local=False, output_dir=None):
    """Render a text-only placeholder card for one outfit that could not be drawn.

    Args:
        index: Outfit number shown on the card and used in the saved filename
            (callers in this module pass 1-based numbers).
        save_to_local: If True and ``output_dir`` is set, also save the card as a PNG.
        output_dir: Directory to save into; created if it does not exist.

    Returns:
        A fully decoded PIL image of the placeholder card.
    """
    card_bg = '#f0f2f6'
    # NOTE(review): matplotlib's default font may lack CJK glyphs, in which case
    # the Chinese label renders as boxes; confirm a CJK font is configured elsewhere.
    fig, ax = plt.subplots(figsize=(3, 4), facecolor=card_bg)
    try:
        # ax.axis('off') stops the axes background patch from being drawn, so
        # the card colour is also set on the figure and passed to savefig.
        ax.set_facecolor(card_bg)
        ax.text(0.5, 0.5, f"穿搭方案 {index}\n(图片生成失败)",
                ha='center', va='center', fontsize=10)
        ax.axis('off')
        buf = BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0.1, dpi=100,
                    facecolor=fig.get_facecolor())
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
    buf.seek(0)
    img = Image.open(buf)
    # Decode now rather than lazily from ``buf`` on first pixel access.
    img.load()
    if save_to_local and output_dir:
        os.makedirs(output_dir, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        img.save(os.path.join(output_dir, f"outfit_{timestamp}_{index}_fallback.png"))
    return img
def generate_fallback_outfits(save_to_local=False, output_dir="C:/Users/Administrator/Desktop"):
    """Return three text-only placeholder outfit cards, numbered 1 through 3.

    Each card is produced by ``generate_single_fallback_outfit``. ``save_to_local``
    and ``output_dir`` are forwarded unchanged to that helper, which decides
    whether to write the card to disk.
    """
    return [
        generate_single_fallback_outfit(number, save_to_local, output_dir)
        for number in range(1, 4)
    ]