diff --git a/client/ui/__init__.py b/client/ui/__init__.py new file mode 100644 index 0000000..f3e0290 --- /dev/null +++ b/client/ui/__init__.py @@ -0,0 +1,29 @@ +# P2P Network Communication - UI Module +""" +用户界面模块 +提供PyQt6图形界面组件 + +需求: 9.1 +""" + +__version__ = "0.1.0" + +from client.ui.main_window import MainWindow +from client.ui.login_dialog import LoginDialog +from client.ui.chat_widget import ChatWidget +from client.ui.contact_list import ContactListWidget +from client.ui.file_transfer_widget import FileTransferWidget +from client.ui.media_player_widget import MediaPlayerWidget +from client.ui.voice_call_widget import VoiceCallWidget +from client.ui.system_tray import SystemTrayManager + +__all__ = [ + 'MainWindow', + 'LoginDialog', + 'ChatWidget', + 'ContactListWidget', + 'FileTransferWidget', + 'MediaPlayerWidget', + 'VoiceCallWidget', + 'SystemTrayManager', +] diff --git a/client/ui/chat_widget.py b/client/ui/chat_widget.py new file mode 100644 index 0000000..4ddc96e --- /dev/null +++ b/client/ui/chat_widget.py @@ -0,0 +1,426 @@ +# P2P Network Communication - Chat Widget +""" +聊天窗口组件 +实现消息输入、发送、显示和历史加载 + +需求: 3.1, 3.2, 3.3, 3.4, 3.5, 9.2 +""" + +import logging +from datetime import datetime +from typing import Optional, List + +from PyQt6.QtWidgets import ( + QWidget, QVBoxLayout, QHBoxLayout, QTextEdit, QLineEdit, + QPushButton, QLabel, QScrollArea, QFrame, QSizePolicy +) +from PyQt6.QtCore import Qt, pyqtSignal, QTimer +from PyQt6.QtGui import QTextCursor, QKeyEvent + +from shared.models import ChatMessage, MessageType, UserInfo + +logger = logging.getLogger(__name__) + + +class MessageBubble(QFrame): + """消息气泡组件""" + + def __init__(self, message: ChatMessage, is_self: bool = False, parent=None): + super().__init__(parent) + self.message = message + self.is_self = is_self + self._setup_ui() + + def _setup_ui(self) -> None: + """设置UI""" + layout = QVBoxLayout(self) + layout.setContentsMargins(10, 5, 10, 5) + layout.setSpacing(2) + + # 消息内容 + content_label = QLabel(self.message.content) + content_label.setWordWrap(True) + content_label.setTextInteractionFlags(Qt.TextInteractionFlag.TextSelectableByMouse) + + # 时间戳 + time_str = self.message.timestamp.strftime("%H:%M") + time_label = QLabel(time_str) + time_label.setStyleSheet("color: #999; font-size: 10px;") + + # 状态指示 + status_text = "✓✓" if self.message.is_read else ("✓" if self.message.is_sent else "⏳") + status_label = QLabel(status_text) + status_label.setStyleSheet("color: #999; font-size: 10px;") + + if self.is_self: + # 自己发送的消息,右对齐 + self.setStyleSheet(""" + QFrame { + background-color: #dcf8c6; + border-radius: 10px; + margin-left: 50px; + } + """) + layout.setAlignment(Qt.AlignmentFlag.AlignRight) + + bottom_layout = QHBoxLayout() + bottom_layout.addStretch() + bottom_layout.addWidget(time_label) + bottom_layout.addWidget(status_label) + + layout.addWidget(content_label) + layout.addLayout(bottom_layout) + else: + # 对方发送的消息,左对齐 + self.setStyleSheet(""" + QFrame { + background-color: white; + border-radius: 10px; + margin-right: 50px; + border: 1px solid #eee; + } + """) + layout.setAlignment(Qt.AlignmentFlag.AlignLeft) + + layout.addWidget(content_label) + layout.addWidget(time_label) + + +class ChatWidget(QWidget): + """ + 聊天窗口组件 + + 实现消息输入和发送 (需求 3.1) + 实现消息显示和历史加载 (需求 3.2, 3.5, 9.2) + 实现消息状态显示 (需求 3.3, 3.4) + """ + + # 信号定义 + message_sent = pyqtSignal(str, str) # peer_id, content + file_send_requested = pyqtSignal(str) # peer_id + image_send_requested = pyqtSignal(str) # peer_id + voice_call_requested = pyqtSignal(str) # peer_id + + def __init__(self, parent=None): + super().__init__(parent) + + self._current_peer_id: Optional[str] = None + self._current_peer_info: Optional[UserInfo] = None + self._messages: List[ChatMessage] = [] + + self._setup_ui() + self._connect_signals() + + logger.info("ChatWidget initialized") + + def _setup_ui(self) -> None: + """设置UI""" + layout = QVBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(0) + + # 聊天对象信息栏 + self._header = self._create_header() + layout.addWidget(self._header) + + # 消息显示区域 + self._message_area = self._create_message_area() + layout.addWidget(self._message_area, 1) + + # 输入区域 + self._input_area = self._create_input_area() + layout.addWidget(self._input_area) + + def _create_header(self) -> QWidget: + """创建头部信息栏""" + header = QFrame() + header.setFixedHeight(50) + header.setStyleSheet(""" + QFrame { + background-color: #f5f5f5; + border-bottom: 1px solid #ddd; + } + """) + + layout = QHBoxLayout(header) + layout.setContentsMargins(15, 0, 15, 0) + + self._peer_name_label = QLabel("选择联系人开始聊天") + self._peer_name_label.setStyleSheet("font-size: 16px; font-weight: bold;") + layout.addWidget(self._peer_name_label) + + layout.addStretch() + + self._peer_status_label = QLabel("") + self._peer_status_label.setStyleSheet("color: #666;") + layout.addWidget(self._peer_status_label) + + return header + + def _create_message_area(self) -> QWidget: + """创建消息显示区域""" + scroll_area = QScrollArea() + scroll_area.setWidgetResizable(True) + scroll_area.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff) + scroll_area.setStyleSheet(""" + QScrollArea { + border: none; + background-color: #e5ddd5; + } + """) + + self._message_container = QWidget() + self._message_layout = QVBoxLayout(self._message_container) + self._message_layout.setContentsMargins(10, 10, 10, 10) + self._message_layout.setSpacing(10) + self._message_layout.addStretch() + + scroll_area.setWidget(self._message_container) + self._scroll_area = scroll_area + + return scroll_area + + def _create_input_area(self) -> QWidget: + """创建输入区域""" + input_frame = QFrame() + input_frame.setStyleSheet(""" + QFrame { + background-color: #f0f0f0; + border-top: 1px solid #ddd; + } + """) + + layout = QVBoxLayout(input_frame) + layout.setContentsMargins(10, 10, 10, 10) + layout.setSpacing(10) + + # 功能按钮行 + button_layout = QHBoxLayout() + button_layout.setSpacing(5) + + self._file_btn = QPushButton("📎") + self._file_btn.setFixedSize(30, 30) + self._file_btn.setToolTip("发送文件") + self._file_btn.clicked.connect(self._on_file_btn_clicked) + button_layout.addWidget(self._file_btn) + + self._image_btn = QPushButton("🖼️") + self._image_btn.setFixedSize(30, 30) + self._image_btn.setToolTip("发送图片") + self._image_btn.clicked.connect(self._on_image_btn_clicked) + button_layout.addWidget(self._image_btn) + + self._voice_btn = QPushButton("📞") + self._voice_btn.setFixedSize(30, 30) + self._voice_btn.setToolTip("语音通话") + self._voice_btn.clicked.connect(self._on_voice_btn_clicked) + button_layout.addWidget(self._voice_btn) + + button_layout.addStretch() + layout.addLayout(button_layout) + + # 输入行 + input_layout = QHBoxLayout() + input_layout.setSpacing(10) + + self._input_edit = QLineEdit() + self._input_edit.setPlaceholderText("输入消息...") + self._input_edit.setMinimumHeight(40) + self._input_edit.setStyleSheet(""" + QLineEdit { + border: 1px solid #ddd; + border-radius: 20px; + padding: 0 15px; + background-color: white; + } + """) + self._input_edit.returnPressed.connect(self._on_send_clicked) + input_layout.addWidget(self._input_edit, 1) + + self._send_btn = QPushButton("发送") + self._send_btn.setMinimumHeight(40) + self._send_btn.setMinimumWidth(80) + self._send_btn.setStyleSheet(""" + QPushButton { + background-color: #4a90d9; + color: white; + border: none; + border-radius: 20px; + font-size: 14px; + } + QPushButton:hover { + background-color: #3a80c9; + } + QPushButton:pressed { + background-color: #2a70b9; + } + QPushButton:disabled { + background-color: #ccc; + } + """) + self._send_btn.clicked.connect(self._on_send_clicked) + input_layout.addWidget(self._send_btn) + + layout.addLayout(input_layout) + + return input_frame + + def _connect_signals(self) -> None: + """连接信号""" + pass + + def _on_send_clicked(self) -> None: + """处理发送按钮点击""" + if not self._current_peer_id: + return + + content = self._input_edit.text().strip() + if not content: + return + + self.message_sent.emit(self._current_peer_id, content) + self._input_edit.clear() + + def _on_file_btn_clicked(self) -> None: + """处理文件按钮点击""" + if self._current_peer_id: + self.file_send_requested.emit(self._current_peer_id) + + def _on_image_btn_clicked(self) -> None: + """处理图片按钮点击""" + if self._current_peer_id: + self.image_send_requested.emit(self._current_peer_id) + + def _on_voice_btn_clicked(self) -> None: + """处理语音按钮点击""" + if self._current_peer_id: + self.voice_call_requested.emit(self._current_peer_id) + + def _scroll_to_bottom(self) -> None: + """滚动到底部""" + QTimer.singleShot(100, lambda: self._scroll_area.verticalScrollBar().setValue( + self._scroll_area.verticalScrollBar().maximum() + )) + + def _clear_messages(self) -> None: + """清空消息显示""" + while self._message_layout.count() > 1: + item = self._message_layout.takeAt(0) + if item.widget(): + item.widget().deleteLater() + + # ==================== 公共方法 ==================== + + def set_peer(self, peer_id: str, peer_info: Optional[UserInfo] = None) -> None: + """ + 设置当前聊天对象 + + 实现用户切换聊天对象 (需求 9.2) + WHEN 用户切换聊天对象 THEN P2P_Client SHALL 加载并显示与该用户的聊天历史 + + Args: + peer_id: 对等端ID + peer_info: 对等端用户信息 + """ + self._current_peer_id = peer_id + self._current_peer_info = peer_info + + if peer_info: + self._peer_name_label.setText(peer_info.display_name or peer_info.username) + self._peer_status_label.setText(peer_info.status.value) + else: + self._peer_name_label.setText(peer_id) + self._peer_status_label.setText("") + + # 清空并加载历史消息 + self._clear_messages() + self._messages.clear() + + # 启用输入 + self._input_edit.setEnabled(True) + self._send_btn.setEnabled(True) + self._input_edit.setFocus() + + logger.info(f"Chat peer set: {peer_id}") + + def add_message(self, message: ChatMessage, is_self: bool = False) -> None: + """ + 添加消息到显示区域 + + 实现消息显示 (需求 3.2) + WHEN P2P_Client 收到文本消息 THEN P2P_Client SHALL 立即显示消息内容和发送者信息 + + Args: + message: 聊天消息 + is_self: 是否是自己发送的消息 + """ + self._messages.append(message) + + bubble = MessageBubble(message, is_self) + # 在stretch之前插入 + self._message_layout.insertWidget(self._message_layout.count() - 1, bubble) + + self._scroll_to_bottom() + + def load_history(self, messages: List[ChatMessage], current_user_id: str) -> None: + """ + 加载聊天历史 + + 实现消息历史加载 (需求 3.5) + WHEN 显示消息历史 THEN P2P_Client SHALL 按时间顺序展示所有消息记录 + + Args: + messages: 消息列表(应按时间排序) + current_user_id: 当前用户ID + """ + self._clear_messages() + self._messages = messages.copy() + + for msg in messages: + is_self = msg.sender_id == current_user_id + bubble = MessageBubble(msg, is_self) + self._message_layout.insertWidget(self._message_layout.count() - 1, bubble) + + self._scroll_to_bottom() + logger.info(f"Loaded {len(messages)} messages") + + def update_message_status(self, message_id: str, is_sent: bool = False, is_read: bool = False) -> None: + """ + 更新消息状态 + + 实现消息状态显示 (需求 3.3, 3.4) + WHEN 消息发送成功 THEN P2P_Client SHALL 显示发送成功的状态标识 + IF 消息发送失败 THEN P2P_Client SHALL 显示错误提示并提供重试选项 + + Args: + message_id: 消息ID + is_sent: 是否已发送 + is_read: 是否已读 + """ + for msg in self._messages: + if msg.message_id == message_id: + msg.is_sent = is_sent + msg.is_read = is_read + break + + # 刷新显示(简化实现,实际应只更新对应的bubble) + if self._current_peer_id and self._messages: + current_user_id = self._messages[0].sender_id if self._messages else "" + self.load_history(self._messages, current_user_id) + + def clear(self) -> None: + """清空聊天窗口""" + self._current_peer_id = None + self._current_peer_info = None + self._clear_messages() + self._messages.clear() + + self._peer_name_label.setText("选择联系人开始聊天") + self._peer_status_label.setText("") + self._input_edit.setEnabled(False) + self._send_btn.setEnabled(False) + + @property + def current_peer_id(self) -> Optional[str]: + """获取当前聊天对象ID""" + return self._current_peer_id diff --git a/client/ui/contact_list.py b/client/ui/contact_list.py new file mode 100644 index 0000000..54e62f5 --- /dev/null +++ b/client/ui/contact_list.py @@ -0,0 +1,331 @@ +# P2P Network Communication - Contact List Widget +""" +联系人列表组件 +显示在线用户列表和联系人管理 + +需求: 9.1, 2.3, 2.4 +""" + +import logging +from typing import Optional, Dict, List + +from PyQt6.QtWidgets import ( + QWidget, QVBoxLayout, QHBoxLayout, QListWidget, QListWidgetItem, + QLabel, QLineEdit, QPushButton, QMenu, QFrame +) +from PyQt6.QtCore import Qt, pyqtSignal, QSize +from PyQt6.QtGui import QIcon, QColor, QAction + +from shared.models import UserInfo, UserStatus + +logger = logging.getLogger(__name__) + + +class ContactItem(QListWidgetItem): + """联系人列表项""" + + def __init__(self, user_info: UserInfo, parent: Optional[QListWidget] = None): + super().__init__(parent) + self.user_info = user_info + self._update_display() + + def _update_display(self) -> None: + """更新显示""" + status_icons = { + UserStatus.ONLINE: "🟢", + UserStatus.OFFLINE: "⚫", + UserStatus.BUSY: "🔴", + UserStatus.AWAY: "🟡", + } + + icon = status_icons.get(self.user_info.status, "⚫") + display_name = self.user_info.display_name or self.user_info.username + self.setText(f"{icon} {display_name}") + + # 设置提示信息 + self.setToolTip( + f"用户名: {self.user_info.username}\n" + f"状态: {self.user_info.status.value}\n" + f"ID: {self.user_info.user_id}" + ) + + def update_status(self, status: UserStatus) -> None: + """更新用户状态""" + self.user_info.status = status + self._update_display() + + +class ContactListWidget(QWidget): + """ + 联系人列表组件 + + 实现联系人列表面板 (需求 9.1) + WHEN 用户请求查看在线用户列表 THEN P2P_Client SHALL 显示当前所有在线用户 (需求 2.3) + WHEN 用户选择一个在线用户 THEN P2P_Client SHALL 显示该用户的基本信息和连接状态 (需求 2.4) + """ + + # 信号定义 + contact_selected = pyqtSignal(str) # 选中联系人时发出,参数为user_id + contact_double_clicked = pyqtSignal(str) # 双击联系人时发出 + refresh_requested = pyqtSignal() # 请求刷新联系人列表 + + def __init__(self, parent: Optional[QWidget] = None): + super().__init__(parent) + + self._contacts: Dict[str, ContactItem] = {} + + self._setup_ui() + self._connect_signals() + + logger.info("ContactListWidget initialized") + + def _setup_ui(self) -> None: + """设置UI""" + layout = QVBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(0) + + # 搜索栏 + search_frame = QFrame() + search_frame.setStyleSheet(""" + QFrame { + background-color: #f5f5f5; + border-bottom: 1px solid #ddd; + } + """) + search_layout = QHBoxLayout(search_frame) + search_layout.setContentsMargins(8, 8, 8, 8) + + self._search_input = QLineEdit() + self._search_input.setPlaceholderText("搜索联系人...") + self._search_input.setClearButtonEnabled(True) + self._search_input.textChanged.connect(self._filter_contacts) + search_layout.addWidget(self._search_input) + + self._refresh_btn = QPushButton("🔄") + self._refresh_btn.setFixedSize(30, 30) + self._refresh_btn.setToolTip("刷新联系人列表") + self._refresh_btn.clicked.connect(self._on_refresh_clicked) + search_layout.addWidget(self._refresh_btn) + + layout.addWidget(search_frame) + + # 联系人列表 + self._list_widget = QListWidget() + self._list_widget.setStyleSheet(""" + QListWidget { + border: none; + background-color: white; + } + QListWidget::item { + padding: 10px; + border-bottom: 1px solid #eee; + } + QListWidget::item:selected { + background-color: #e3f2fd; + color: black; + } + QListWidget::item:hover { + background-color: #f5f5f5; + } + """) + self._list_widget.setContextMenuPolicy(Qt.ContextMenuPolicy.CustomContextMenu) + self._list_widget.customContextMenuRequested.connect(self._show_context_menu) + layout.addWidget(self._list_widget, 1) + + # 底部状态栏 + status_frame = QFrame() + status_frame.setStyleSheet(""" + QFrame { + background-color: #f5f5f5; + border-top: 1px solid #ddd; + } + """) + status_layout = QHBoxLayout(status_frame) + status_layout.setContentsMargins(8, 4, 8, 4) + + self._status_label = QLabel("0 位联系人在线") + self._status_label.setStyleSheet("color: #666; font-size: 12px;") + status_layout.addWidget(self._status_label) + + layout.addWidget(status_frame) + + def _connect_signals(self) -> None: + """连接信号""" + self._list_widget.itemClicked.connect(self._on_item_clicked) + self._list_widget.itemDoubleClicked.connect(self._on_item_double_clicked) + + def _on_item_clicked(self, item: QListWidgetItem) -> None: + """处理项目点击""" + if isinstance(item, ContactItem): + self.contact_selected.emit(item.user_info.user_id) + + def _on_item_double_clicked(self, item: QListWidgetItem) -> None: + """处理项目双击""" + if isinstance(item, ContactItem): + self.contact_double_clicked.emit(item.user_info.user_id) + + def _on_refresh_clicked(self) -> None: + """处理刷新按钮点击""" + self.refresh_requested.emit() + + def _filter_contacts(self, text: str) -> None: + """过滤联系人列表""" + text = text.lower() + for i in range(self._list_widget.count()): + item = self._list_widget.item(i) + if isinstance(item, ContactItem): + visible = ( + text in item.user_info.username.lower() or + text in (item.user_info.display_name or "").lower() + ) + item.setHidden(not visible) + + def _show_context_menu(self, pos) -> None: + """显示右键菜单""" + item = self._list_widget.itemAt(pos) + if not isinstance(item, ContactItem): + return + + menu = QMenu(self) + + chat_action = QAction("发送消息", self) + chat_action.triggered.connect(lambda: self.contact_double_clicked.emit(item.user_info.user_id)) + menu.addAction(chat_action) + + menu.addSeparator() + + info_action = QAction("查看资料", self) + info_action.triggered.connect(lambda: self._show_user_info(item.user_info)) + menu.addAction(info_action) + + menu.exec(self._list_widget.mapToGlobal(pos)) + + def _show_user_info(self, user_info: UserInfo) -> None: + """显示用户信息""" + from PyQt6.QtWidgets import QMessageBox + + QMessageBox.information( + self, + "用户信息", + f"用户名: {user_info.username}\n" + f"显示名: {user_info.display_name or '未设置'}\n" + f"状态: {user_info.status.value}\n" + f"ID: {user_info.user_id}" + ) + + def _update_status_label(self) -> None: + """更新状态标签""" + online_count = sum( + 1 for item in self._contacts.values() + if item.user_info.status == UserStatus.ONLINE + ) + total_count = len(self._contacts) + self._status_label.setText(f"{online_count}/{total_count} 位联系人在线") + + # ==================== 公共方法 ==================== + + def add_contact(self, user_info: UserInfo) -> None: + """ + 添加联系人 + + Args: + user_info: 用户信息 + """ + if user_info.user_id in self._contacts: + # 更新现有联系人 + self.update_contact_status(user_info.user_id, user_info.status) + return + + item = ContactItem(user_info) + self._list_widget.addItem(item) + self._contacts[user_info.user_id] = item + self._update_status_label() + + logger.debug(f"Contact added: {user_info.username}") + + def remove_contact(self, user_id: str) -> None: + """ + 移除联系人 + + Args: + user_id: 用户ID + """ + if user_id not in self._contacts: + return + + item = self._contacts.pop(user_id) + row = self._list_widget.row(item) + self._list_widget.takeItem(row) + self._update_status_label() + + logger.debug(f"Contact removed: {user_id}") + + def update_contact_status(self, user_id: str, status: UserStatus) -> None: + """ + 更新联系人状态 + + Args: + user_id: 用户ID + status: 新状态 + """ + if user_id not in self._contacts: + return + + self._contacts[user_id].update_status(status) + self._update_status_label() + + logger.debug(f"Contact status updated: {user_id} -> {status.value}") + + def set_contacts(self, users: List[UserInfo]) -> None: + """ + 设置联系人列表 + + Args: + users: 用户信息列表 + """ + self.clear_contacts() + for user in users: + self.add_contact(user) + + def clear_contacts(self) -> None: + """清空联系人列表""" + self._list_widget.clear() + self._contacts.clear() + self._update_status_label() + + def get_selected_contact(self) -> Optional[UserInfo]: + """ + 获取当前选中的联系人 + + Returns: + 选中的用户信息,如果没有选中则返回None + """ + item = self._list_widget.currentItem() + if isinstance(item, ContactItem): + return item.user_info + return None + + def get_contact(self, user_id: str) -> Optional[UserInfo]: + """ + 获取指定联系人信息 + + Args: + user_id: 用户ID + + Returns: + 用户信息 + """ + if user_id in self._contacts: + return self._contacts[user_id].user_info + return None + + def select_contact(self, user_id: str) -> None: + """ + 选中指定联系人 + + Args: + user_id: 用户ID + """ + if user_id in self._contacts: + self._list_widget.setCurrentItem(self._contacts[user_id]) diff --git a/client/ui/file_transfer_widget.py b/client/ui/file_transfer_widget.py new file mode 100644 index 0000000..7ea6b7b --- /dev/null +++ b/client/ui/file_transfer_widget.py @@ -0,0 +1,673 @@ +# P2P Network Communication - File Transfer Widget +""" +文件传输界面组件 +实现文件选择、传输进度显示和图片预览 + +需求: 4.1, 4.3, 4.6, 5.2, 5.3, 5.4, 5.6 +""" + +import logging +import os +from typing import Optional, Dict + +from PyQt6.QtWidgets import ( + QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, + QProgressBar, QFileDialog, QScrollArea, QFrame, QMessageBox, + QDialog, QSlider, QSizePolicy +) +from PyQt6.QtCore import Qt, pyqtSignal, QSize +from PyQt6.QtGui import QPixmap, QTransform + +from shared.models import TransferProgress, TransferStatus + +logger = logging.getLogger(__name__) + + +class TransferItem(QFrame): + """传输项组件""" + + cancel_requested = pyqtSignal(str) # file_id + + def __init__(self, file_id: str, file_name: str, total_size: int, + is_upload: bool = True, parent=None): + super().__init__(parent) + self.file_id = file_id + self.file_name = file_name + self.total_size = total_size + self.is_upload = is_upload + self._setup_ui() + + def _setup_ui(self) -> None: + """设置UI""" + self.setFrameShape(QFrame.Shape.StyledPanel) + self.setStyleSheet(""" + QFrame { + background-color: white; + border: 1px solid #ddd; + border-radius: 5px; + margin: 2px; + } + """) + + layout = QVBoxLayout(self) + layout.setContentsMargins(10, 10, 10, 10) + layout.setSpacing(5) + + # 文件信息行 + info_layout = QHBoxLayout() + + icon = "⬆️" if self.is_upload else "⬇️" + self._name_label = QLabel(f"{icon} {self.file_name}") + self._name_label.setStyleSheet("font-weight: bold;") + info_layout.addWidget(self._name_label, 1) + + self._size_label = QLabel(self._format_size(self.total_size)) + self._size_label.setStyleSheet("color: #666;") + info_layout.addWidget(self._size_label) + + self._cancel_btn = QPushButton("✕") + self._cancel_btn.setFixedSize(20, 20) + self._cancel_btn.setStyleSheet(""" + QPushButton { + border: none; + color: #999; + } + QPushButton:hover { + color: red; + } + """) + self._cancel_btn.clicked.connect(lambda: self.cancel_requested.emit(self.file_id)) + info_layout.addWidget(self._cancel_btn) + + layout.addLayout(info_layout) + + # 进度条 + self._progress_bar = QProgressBar() + self._progress_bar.setRange(0, 100) + self._progress_bar.setValue(0) + self._progress_bar.setTextVisible(True) + self._progress_bar.setStyleSheet(""" + QProgressBar { + border: 1px solid #ddd; + border-radius: 3px; + text-align: center; + height: 20px; + } + QProgressBar::chunk { + background-color: #4a90d9; + border-radius: 2px; + } + """) + layout.addWidget(self._progress_bar) + + # 状态行 + status_layout = QHBoxLayout() + + self._status_label = QLabel("等待中...") + self._status_label.setStyleSheet("color: #666; font-size: 12px;") + status_layout.addWidget(self._status_label) + + status_layout.addStretch() + + self._speed_label = QLabel("") + self._speed_label.setStyleSheet("color: #666; font-size: 12px;") + status_layout.addWidget(self._speed_label) + + layout.addLayout(status_layout) + + def _format_size(self, size: int) -> str: + """格式化文件大小""" + if size < 1024: + return f"{size} B" + elif size < 1024 * 1024: + return f"{size / 1024:.1f} KB" + elif size < 1024 * 1024 * 1024: + return f"{size / (1024 * 1024):.1f} MB" + else: + return f"{size / (1024 * 1024 * 1024):.2f} GB" + + def _format_speed(self, speed: float) -> str: + """格式化传输速度""" + return f"{self._format_size(int(speed))}/s" + + def update_progress(self, progress: TransferProgress) -> None: + """更新进度""" + percent = int(progress.progress_percent) + self._progress_bar.setValue(percent) + self._speed_label.setText(self._format_speed(progress.speed)) + + if progress.eta > 0: + eta_str = f"剩余 {int(progress.eta)}秒" + else: + eta_str = "" + self._status_label.setText(f"传输中... {eta_str}") + + def set_status(self, status: TransferStatus) -> None: + """设置状态""" + status_texts = { + TransferStatus.PENDING: "等待中...", + TransferStatus.IN_PROGRESS: "传输中...", + TransferStatus.COMPLETED: "✓ 完成", + TransferStatus.FAILED: "✕ 失败", + TransferStatus.CANCELLED: "已取消", + TransferStatus.PAUSED: "已暂停", + } + self._status_label.setText(status_texts.get(status, "未知")) + + if status == TransferStatus.COMPLETED: + self._progress_bar.setValue(100) + self._cancel_btn.hide() + self._progress_bar.setStyleSheet(""" + QProgressBar { + border: 1px solid #ddd; + border-radius: 3px; + text-align: center; + height: 20px; + } + QProgressBar::chunk { + background-color: #4caf50; + border-radius: 2px; + } + """) + elif status == TransferStatus.FAILED: + self._cancel_btn.hide() + self._progress_bar.setStyleSheet(""" + QProgressBar { + border: 1px solid #ddd; + border-radius: 3px; + text-align: center; + height: 20px; + } + QProgressBar::chunk { + background-color: #f44336; + border-radius: 2px; + } + """) + + +class FileTransferWidget(QWidget): + """ + 文件传输界面组件 + + 实现文件选择对话框 (需求 4.1) + 实现传输进度显示 (需求 4.3) + 实现图片预览和查看 (需求 5.2, 5.3, 5.4, 5.6) + """ + + # 信号定义 + send_file_requested = pyqtSignal(str, str) # peer_id, file_path + cancel_transfer_requested = pyqtSignal(str) # file_id + + def __init__(self, parent=None): + super().__init__(parent) + + self._transfers: Dict[str, TransferItem] = {} + + self._setup_ui() + + logger.info("FileTransferWidget initialized") + + def _setup_ui(self) -> None: + """设置UI""" + layout = QVBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(0) + + # 标题栏 + header = QFrame() + header.setFixedHeight(40) + header.setStyleSheet(""" + QFrame { + background-color: #f5f5f5; + border-bottom: 1px solid #ddd; + } + """) + header_layout = QHBoxLayout(header) + header_layout.setContentsMargins(10, 0, 10, 0) + + title_label = QLabel("文件传输") + title_label.setStyleSheet("font-weight: bold;") + header_layout.addWidget(title_label) + + header_layout.addStretch() + + self._clear_btn = QPushButton("清除已完成") + self._clear_btn.clicked.connect(self._clear_completed) + header_layout.addWidget(self._clear_btn) + + layout.addWidget(header) + + # 传输列表 + scroll_area = QScrollArea() + scroll_area.setWidgetResizable(True) + scroll_area.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff) + + self._transfer_container = QWidget() + self._transfer_layout = QVBoxLayout(self._transfer_container) + self._transfer_layout.setContentsMargins(5, 5, 5, 5) + self._transfer_layout.setSpacing(5) + self._transfer_layout.addStretch() + + scroll_area.setWidget(self._transfer_container) + layout.addWidget(scroll_area, 1) + + # 空状态提示 + self._empty_label = QLabel("暂无传输任务") + self._empty_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + self._empty_label.setStyleSheet("color: #999;") + self._transfer_layout.insertWidget(0, self._empty_label) + + def _clear_completed(self) -> None: + """清除已完成的传输""" + to_remove = [] + for file_id, item in self._transfers.items(): + # 检查是否已完成(简化判断) + if item._progress_bar.value() == 100: + to_remove.append(file_id) + + for file_id in to_remove: + self.remove_transfer(file_id) + + def _update_empty_state(self) -> None: + """更新空状态显示""" + self._empty_label.setVisible(len(self._transfers) == 0) + + # ==================== 公共方法 ==================== + + def add_transfer(self, file_id: str, file_name: str, total_size: int, + is_upload: bool = True) -> None: + """ + 添加传输任务 + + Args: + file_id: 文件ID + file_name: 文件名 + total_size: 文件大小 + is_upload: 是否是上传 + """ + if file_id in self._transfers: + return + + item = TransferItem(file_id, file_name, total_size, is_upload) + item.cancel_requested.connect(self._on_cancel_requested) + + self._transfer_layout.insertWidget(self._transfer_layout.count() - 1, item) + self._transfers[file_id] = item + self._update_empty_state() + + logger.debug(f"Transfer added: {file_name}") + + def update_transfer_progress(self, file_id: str, progress: TransferProgress) -> None: + """ + 更新传输进度 + + 实现传输进度显示 (需求 4.3) + WHILE 文件传输进行中 THEN P2P_Client SHALL 显示传输进度百分比和传输速度 + + Args: + file_id: 文件ID + progress: 进度信息 + """ + if file_id in self._transfers: + self._transfers[file_id].update_progress(progress) + + def update_transfer_status(self, file_id: str, status: TransferStatus) -> None: + """ + 更新传输状态 + + Args: + file_id: 文件ID + status: 状态 + """ + if file_id in self._transfers: + self._transfers[file_id].set_status(status) + + def remove_transfer(self, file_id: str) -> None: + """ + 移除传输任务 + + Args: + file_id: 文件ID + """ + if file_id not in self._transfers: + return + + item = self._transfers.pop(file_id) + self._transfer_layout.removeWidget(item) + item.deleteLater() + self._update_empty_state() + + def _on_cancel_requested(self, file_id: str) -> None: + """处理取消请求""" + self.cancel_transfer_requested.emit(file_id) + + def show_file_dialog(self, peer_id: str) -> Optional[str]: + """ + 显示文件选择对话框 + + 实现文件选择对话框 (需求 4.1) + WHEN 用户选择发送文件 THEN File_Transfer_Module SHALL 允许用户从本地选择任意类型的文件 + + Args: + peer_id: 目标用户ID + + Returns: + 选择的文件路径,如果取消则返回None + """ + file_path, _ = QFileDialog.getOpenFileName( + self, + "选择文件", + "", + "所有文件 (*.*)" + ) + + if file_path: + self.send_file_requested.emit(peer_id, file_path) + return file_path + return None + + def show_image_dialog(self, peer_id: str) -> Optional[str]: + """ + 显示图片选择对话框 + + Args: + peer_id: 目标用户ID + + Returns: + 选择的图片路径 + """ + file_path, _ = QFileDialog.getOpenFileName( + self, + "选择图片", + "", + "图片文件 (*.jpg *.jpeg *.png *.gif *.bmp);;所有文件 (*.*)" + ) + + if file_path: + self.send_file_requested.emit(peer_id, file_path) + return file_path + return None + + def show_save_dialog(self, file_name: str) -> Optional[str]: + """ + 显示保存文件对话框 + + 实现用户选择保存位置 (需求 4.6) + WHEN 接收到文件 THEN P2P_Client SHALL 允许用户选择保存位置 + + Args: + file_name: 默认文件名 + + Returns: + 保存路径 + """ + file_path, _ = QFileDialog.getSaveFileName( + self, + "保存文件", + file_name, + "所有文件 (*.*)" + ) + return file_path if file_path else None + + def show_image_preview(self, image_path: str) -> None: + """ + 显示图片预览 + + 实现图片预览和查看 (需求 5.4) + WHEN 用户点击图片缩略图 THEN P2P_Client SHALL 打开图片查看器显示完整图片 + + Args: + image_path: 图片路径 + """ + dialog = ImagePreviewDialog(image_path, self) + dialog.exec() + + def create_image_thumbnail(self, image_path: str) -> Optional[ImageThumbnail]: + """ + 创建图片缩略图组件 + + 实现图片缩略图预览 (需求 5.2, 5.3) + + Args: + image_path: 图片路径 + + Returns: + 缩略图组件 + """ + if not os.path.exists(image_path): + return None + + thumbnail = ImageThumbnail(image_path) + thumbnail.clicked.connect(self.show_image_preview) + return thumbnail + + + +class ImagePreviewDialog(QDialog): + """ + 图片预览对话框 + + 实现图片预览和查看 (需求 5.3, 5.4) + WHEN 图片传输完成 THEN P2P_Client SHALL 在聊天窗口内直接显示图片缩略图 + WHEN 用户点击图片缩略图 THEN P2P_Client SHALL 打开图片查看器显示完整图片 + """ + + def __init__(self, image_path: str, parent=None): + super().__init__(parent) + self.image_path = image_path + self._rotation = 0 + self._zoom_level = 100 + self._setup_ui() + self._load_image() + + def _setup_ui(self) -> None: + """设置UI""" + self.setWindowTitle("图片预览") + self.setMinimumSize(600, 500) + self.resize(800, 600) + + layout = QVBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(0) + + # 图片显示区域 + scroll_area = QScrollArea() + scroll_area.setWidgetResizable(True) + scroll_area.setAlignment(Qt.AlignmentFlag.AlignCenter) + scroll_area.setStyleSheet("background-color: #2d2d2d;") + + self._image_label = QLabel() + self._image_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + scroll_area.setWidget(self._image_label) + + layout.addWidget(scroll_area, 1) + + # 工具栏 + toolbar = QFrame() + toolbar.setFixedHeight(50) + toolbar.setStyleSheet(""" + QFrame { + background-color: #f5f5f5; + border-top: 1px solid #ddd; + } + """) + + toolbar_layout = QHBoxLayout(toolbar) + toolbar_layout.setContentsMargins(10, 5, 10, 5) + toolbar_layout.setSpacing(10) + + # 缩放控制 + zoom_out_btn = QPushButton("➖") + zoom_out_btn.setFixedSize(30, 30) + zoom_out_btn.clicked.connect(self._zoom_out) + toolbar_layout.addWidget(zoom_out_btn) + + self._zoom_slider = QSlider(Qt.Orientation.Horizontal) + self._zoom_slider.setRange(25, 400) + self._zoom_slider.setValue(100) + self._zoom_slider.setFixedWidth(150) + self._zoom_slider.valueChanged.connect(self._on_zoom_changed) + toolbar_layout.addWidget(self._zoom_slider) + + zoom_in_btn = QPushButton("➕") + zoom_in_btn.setFixedSize(30, 30) + zoom_in_btn.clicked.connect(self._zoom_in) + toolbar_layout.addWidget(zoom_in_btn) + + self._zoom_label = QLabel("100%") + self._zoom_label.setFixedWidth(50) + toolbar_layout.addWidget(self._zoom_label) + + toolbar_layout.addStretch() + + # 旋转按钮 + rotate_left_btn = QPushButton("↺") + rotate_left_btn.setFixedSize(30, 30) + rotate_left_btn.setToolTip("向左旋转") + rotate_left_btn.clicked.connect(self._rotate_left) + toolbar_layout.addWidget(rotate_left_btn) + + rotate_right_btn = QPushButton("↻") + rotate_right_btn.setFixedSize(30, 30) + rotate_right_btn.setToolTip("向右旋转") + rotate_right_btn.clicked.connect(self._rotate_right) + toolbar_layout.addWidget(rotate_right_btn) + + toolbar_layout.addStretch() + + # 关闭按钮 + close_btn = QPushButton("关闭") + close_btn.clicked.connect(self.close) + toolbar_layout.addWidget(close_btn) + + layout.addWidget(toolbar) + + def _load_image(self) -> None: + """加载图片""" + if not os.path.exists(self.image_path): + self._image_label.setText("图片不存在") + return + + self._original_pixmap = QPixmap(self.image_path) + if self._original_pixmap.isNull(): + self._image_label.setText("无法加载图片") + return + + self._update_display() + + def _update_display(self) -> None: + """更新图片显示""" + if not hasattr(self, '_original_pixmap') or self._original_pixmap.isNull(): + return + + # 应用旋转 + transform = QTransform() + transform.rotate(self._rotation) + rotated = self._original_pixmap.transformed(transform, Qt.TransformationMode.SmoothTransformation) + + # 应用缩放 + scale = self._zoom_level / 100.0 + scaled = rotated.scaled( + int(rotated.width() * scale), + int(rotated.height() * scale), + Qt.AspectRatioMode.KeepAspectRatio, + Qt.TransformationMode.SmoothTransformation + ) + + self._image_label.setPixmap(scaled) + + def _zoom_in(self) -> None: + """放大""" + self._zoom_slider.setValue(min(400, self._zoom_level + 25)) + + def _zoom_out(self) -> None: + """缩小""" + self._zoom_slider.setValue(max(25, self._zoom_level - 25)) + + def _on_zoom_changed(self, value: int) -> None: + """缩放变化""" + self._zoom_level = value + self._zoom_label.setText(f"{value}%") + self._update_display() + + def _rotate_left(self) -> None: + """向左旋转""" + self._rotation = (self._rotation - 90) % 360 + self._update_display() + + def _rotate_right(self) -> None: + """向右旋转""" + self._rotation = (self._rotation + 90) % 360 + self._update_display() + + +class ImageThumbnail(QFrame): + """ + 图片缩略图组件 + + 实现图片缩略图预览 (需求 5.2) + WHEN 图片发送前 THEN P2P_Client SHALL 显示图片缩略图预览 + """ + + clicked = pyqtSignal(str) # image_path + + def __init__(self, image_path: str, thumbnail_size: tuple = (150, 150), parent=None): + super().__init__(parent) + self.image_path = image_path + self.thumbnail_size = thumbnail_size + self._setup_ui() + + def _setup_ui(self) -> None: + """设置UI""" + self.setFrameShape(QFrame.Shape.StyledPanel) + self.setCursor(Qt.CursorShape.PointingHandCursor) + self.setStyleSheet(""" + QFrame { + background-color: white; + border: 1px solid #ddd; + border-radius: 5px; + } + QFrame:hover { + border-color: #4a90d9; + } + """) + + layout = QVBoxLayout(self) + layout.setContentsMargins(5, 5, 5, 5) + layout.setSpacing(5) + + # 缩略图 + self._image_label = QLabel() + self._image_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + self._image_label.setFixedSize(*self.thumbnail_size) + + if os.path.exists(self.image_path): + pixmap = QPixmap(self.image_path) + if not pixmap.isNull(): + scaled = pixmap.scaled( + *self.thumbnail_size, + Qt.AspectRatioMode.KeepAspectRatio, + Qt.TransformationMode.SmoothTransformation + ) + self._image_label.setPixmap(scaled) + else: + self._image_label.setText("无法加载") + else: + self._image_label.setText("图片不存在") + + layout.addWidget(self._image_label) + + # 文件名 + file_name = os.path.basename(self.image_path) + if len(file_name) > 20: + file_name = file_name[:17] + "..." + name_label = QLabel(file_name) + name_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + name_label.setStyleSheet("font-size: 11px; color: #666;") + layout.addWidget(name_label) + + def mousePressEvent(self, event) -> None: + """鼠标点击事件""" + if event.button() == Qt.MouseButton.LeftButton: + self.clicked.emit(self.image_path) + super().mousePressEvent(event) diff --git a/client/ui/login_dialog.py b/client/ui/login_dialog.py new file mode 100644 index 0000000..db8c47a --- /dev/null +++ b/client/ui/login_dialog.py @@ -0,0 +1,304 @@ +# P2P Network Communication - Login Dialog +""" +登录对话框 +实现用户注册和登录界面 + +需求: 2.1 +""" + +import logging +import uuid +import re +from typing import Optional + +from PyQt6.QtWidgets import ( + QDialog, QVBoxLayout, QHBoxLayout, QLabel, QLineEdit, + QPushButton, QMessageBox, QFormLayout, QFrame, QTabWidget, + QWidget, QCheckBox +) +from PyQt6.QtCore import Qt, pyqtSignal + +from shared.models import UserInfo, UserStatus + +logger = logging.getLogger(__name__) + + +class LoginDialog(QDialog): + """ + 登录对话框 + + 实现用户名设置界面和登录/注册流程 (需求 2.1) + WHEN 用户首次使用应用程序 THEN P2P_Client SHALL 允许用户设置唯一的用户名和显示名称 + """ + + # 信号定义 + login_successful = pyqtSignal(UserInfo) + + def __init__(self, parent=None): + super().__init__(parent) + + self._user_info: Optional[UserInfo] = None + + self._setup_ui() + + logger.info("LoginDialog initialized") + + def _setup_ui(self) -> None: + """设置UI""" + self.setWindowTitle("登录 / 注册") + self.setFixedSize(450, 400) + self.setModal(True) + + layout = QVBoxLayout(self) + layout.setSpacing(15) + layout.setContentsMargins(20, 20, 20, 20) + + # 标题 + title_label = QLabel("P2P 通信应用") + title_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + title_label.setStyleSheet("font-size: 28px; font-weight: bold; color: #4a90d9;") + layout.addWidget(title_label) + + subtitle_label = QLabel("安全、快速的点对点通信") + subtitle_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + subtitle_label.setStyleSheet("font-size: 12px; color: #666;") + layout.addWidget(subtitle_label) + + layout.addSpacing(10) + + # 选项卡 + self._tab_widget = QTabWidget() + self._tab_widget.addTab(self._create_login_tab(), "登录") + self._tab_widget.addTab(self._create_register_tab(), "注册") + layout.addWidget(self._tab_widget) + + # 设置焦点 + self._login_username_input.setFocus() + + def _create_login_tab(self) -> QWidget: + """创建登录选项卡""" + tab = QWidget() + layout = QVBoxLayout(tab) + layout.setSpacing(15) + layout.setContentsMargins(20, 20, 20, 20) + + # 表单 + form_layout = QFormLayout() + form_layout.setSpacing(12) + + self._login_username_input = QLineEdit() + self._login_username_input.setPlaceholderText("请输入用户名") + self._login_username_input.setMinimumHeight(38) + self._login_username_input.setStyleSheet(self._get_input_style()) + form_layout.addRow("用户名:", self._login_username_input) + + self._login_server_input = QLineEdit() + self._login_server_input.setPlaceholderText("服务器地址:端口") + self._login_server_input.setText("127.0.0.1:8888") + self._login_server_input.setMinimumHeight(38) + self._login_server_input.setStyleSheet(self._get_input_style()) + form_layout.addRow("服务器:", self._login_server_input) + + layout.addLayout(form_layout) + + # 记住设置 + self._remember_checkbox = QCheckBox("记住我的设置") + layout.addWidget(self._remember_checkbox) + + layout.addStretch() + + # 登录按钮 + self._login_btn = QPushButton("登录") + self._login_btn.setMinimumHeight(45) + self._login_btn.setStyleSheet(self._get_button_style()) + self._login_btn.clicked.connect(self._on_login) + layout.addWidget(self._login_btn) + + # 快捷键 + self._login_username_input.returnPressed.connect(self._on_login) + + return tab + + def _create_register_tab(self) -> QWidget: + """创建注册选项卡""" + tab = QWidget() + layout = QVBoxLayout(tab) + layout.setSpacing(15) + layout.setContentsMargins(20, 20, 20, 20) + + # 表单 + form_layout = QFormLayout() + form_layout.setSpacing(12) + + self._reg_username_input = QLineEdit() + self._reg_username_input.setPlaceholderText("3-20个字符,字母数字下划线") + self._reg_username_input.setMinimumHeight(38) + self._reg_username_input.setStyleSheet(self._get_input_style()) + form_layout.addRow("用户名:", self._reg_username_input) + + self._reg_display_name_input = QLineEdit() + self._reg_display_name_input.setPlaceholderText("显示给其他用户的名称(可选)") + self._reg_display_name_input.setMinimumHeight(38) + self._reg_display_name_input.setStyleSheet(self._get_input_style()) + form_layout.addRow("显示名:", self._reg_display_name_input) + + self._reg_server_input = QLineEdit() + self._reg_server_input.setPlaceholderText("服务器地址:端口") + self._reg_server_input.setText("127.0.0.1:8888") + self._reg_server_input.setMinimumHeight(38) + self._reg_server_input.setStyleSheet(self._get_input_style()) + form_layout.addRow("服务器:", self._reg_server_input) + + layout.addLayout(form_layout) + + # 提示信息 + hint_label = QLabel("用户名将作为您的唯一标识,注册后不可更改") + hint_label.setStyleSheet("color: #999; font-size: 11px;") + layout.addWidget(hint_label) + + layout.addStretch() + + # 注册按钮 + self._register_btn = QPushButton("注册并登录") + self._register_btn.setMinimumHeight(45) + self._register_btn.setStyleSheet(self._get_button_style()) + self._register_btn.clicked.connect(self._on_register) + layout.addWidget(self._register_btn) + + return tab + + def _get_input_style(self) -> str: + """获取输入框样式""" + return """ + QLineEdit { + border: 1px solid #ddd; + border-radius: 5px; + padding: 0 12px; + background-color: #fafafa; + } + QLineEdit:focus { + border-color: #4a90d9; + background-color: white; + } + """ + + def _get_button_style(self) -> str: + """获取按钮样式""" + return """ + QPushButton { + background-color: #4a90d9; + color: white; + border: none; + border-radius: 5px; + font-size: 15px; + font-weight: bold; + } + QPushButton:hover { + background-color: #3a80c9; + } + QPushButton:pressed { + background-color: #2a70b9; + } + QPushButton:disabled { + background-color: #ccc; + } + """ + + def _validate_username(self, username: str) -> tuple: + """ + 验证用户名 + + Args: + username: 用户名 + + Returns: + (is_valid, error_message) + """ + if not username: + return False, "请输入用户名" + + if len(username) < 3: + return False, "用户名至少需要3个字符" + + if len(username) > 20: + return False, "用户名不能超过20个字符" + + if not re.match(r'^[a-zA-Z0-9_]+$', username): + return False, "用户名只能包含字母、数字和下划线" + + return True, "" + + def _on_login(self) -> None: + """处理登录""" + username = self._login_username_input.text().strip() + + is_valid, error_msg = self._validate_username(username) + if not is_valid: + QMessageBox.warning(self, "错误", error_msg) + self._login_username_input.setFocus() + return + + # 创建用户信息 + self._user_info = UserInfo( + user_id=str(uuid.uuid4()), + username=username, + display_name=username, + status=UserStatus.ONLINE + ) + + logger.info(f"User logged in: {username}") + self.login_successful.emit(self._user_info) + self.accept() + + def _on_register(self) -> None: + """处理注册""" + username = self._reg_username_input.text().strip() + display_name = self._reg_display_name_input.text().strip() + + is_valid, error_msg = self._validate_username(username) + if not is_valid: + QMessageBox.warning(self, "错误", error_msg) + self._reg_username_input.setFocus() + return + + # 创建用户信息 + self._user_info = UserInfo( + user_id=str(uuid.uuid4()), + username=username, + display_name=display_name or username, + status=UserStatus.ONLINE + ) + + logger.info(f"User registered and logged in: {username}") + self.login_successful.emit(self._user_info) + self.accept() + + def get_user_info(self) -> Optional[UserInfo]: + """ + 获取用户信息 + + Returns: + 用户信息,如果未登录则返回None + """ + return self._user_info + + def get_server_address(self) -> tuple: + """ + 获取服务器地址 + + Returns: + (host, port) 元组 + """ + # 根据当前选项卡获取服务器地址 + if self._tab_widget.currentIndex() == 0: + server = self._login_server_input.text().strip() + else: + server = self._reg_server_input.text().strip() + + if ":" in server: + host, port = server.split(":", 1) + try: + return host, int(port) + except ValueError: + return host, 8888 + return server, 8888 diff --git a/client/ui/main_window.py b/client/ui/main_window.py new file mode 100644 index 0000000..c7a294d --- /dev/null +++ b/client/ui/main_window.py @@ -0,0 +1,550 @@ +# P2P Network Communication - Main Window +""" +主窗口模块 +实现应用程序主界面,包含联系人列表、聊天窗口和功能按钮 + +需求: 9.1 +""" + +import logging +from typing import Optional, Dict, Any + +from PyQt6.QtWidgets import ( + QMainWindow, QWidget, QVBoxLayout, QHBoxLayout, + QSplitter, QStatusBar, QToolBar, QMenuBar, + QMenu, QMessageBox, QLabel, QFrame +) +from PyQt6.QtCore import Qt, pyqtSignal, QSize +from PyQt6.QtGui import QAction, QIcon, QCloseEvent + +from config import ClientConfig +from shared.models import UserInfo, UserStatus, Message, MessageType + +logger = logging.getLogger(__name__) + + +class MainWindow(QMainWindow): + """ + 主窗口 + + 实现应用程序主界面 (需求 9.1) + WHEN P2P_Client 启动 THEN P2P_Client SHALL 显示主界面包含联系人列表、聊天窗口和功能按钮 + """ + + # 信号定义 + message_received = pyqtSignal(Message) + user_status_changed = pyqtSignal(str, UserStatus) + connection_state_changed = pyqtSignal(str) + + def __init__(self, config: Optional[ClientConfig] = None, parent: Optional[QWidget] = None): + """ + 初始化主窗口 + + Args: + config: 客户端配置 + parent: 父窗口 + """ + super().__init__(parent) + + self.config = config or ClientConfig() + self._current_user: Optional[UserInfo] = None + self._current_chat_peer: Optional[str] = None + + # 子组件引用(延迟导入避免循环依赖) + self._contact_list: Optional[QWidget] = None + self._chat_widget: Optional[QWidget] = None + self._file_transfer_widget: Optional[QWidget] = None + self._media_player_widget: Optional[QWidget] = None + self._voice_call_widget: Optional[QWidget] = None + self._system_tray: Optional[Any] = None + + self._setup_ui() + self._setup_menu_bar() + self._setup_tool_bar() + self._setup_status_bar() + self._connect_signals() + + logger.info("MainWindow initialized") + + def _setup_ui(self) -> None: + """设置UI布局""" + self.setWindowTitle("P2P 通信应用") + self.setMinimumSize(800, 600) + self.resize(self.config.window_width, self.config.window_height) + + # 创建中央部件 + central_widget = QWidget() + self.setCentralWidget(central_widget) + + # 主布局 + main_layout = QHBoxLayout(central_widget) + main_layout.setContentsMargins(0, 0, 0, 0) + main_layout.setSpacing(0) + + # 创建分割器 + self._splitter = QSplitter(Qt.Orientation.Horizontal) + main_layout.addWidget(self._splitter) + + # 左侧面板:联系人列表 + self._left_panel = self._create_left_panel() + self._splitter.addWidget(self._left_panel) + + # 右侧面板:聊天窗口和功能区 + self._right_panel = self._create_right_panel() + self._splitter.addWidget(self._right_panel) + + # 设置分割比例 + self._splitter.setSizes([250, 750]) + self._splitter.setStretchFactor(0, 0) + self._splitter.setStretchFactor(1, 1) + + def _create_left_panel(self) -> QWidget: + """ + 创建左侧面板(联系人列表) + + 实现联系人列表面板 (需求 9.1) + """ + panel = QFrame() + panel.setFrameShape(QFrame.Shape.StyledPanel) + panel.setMinimumWidth(200) + panel.setMaximumWidth(400) + + layout = QVBoxLayout(panel) + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(0) + + # 用户信息区域 + self._user_info_widget = self._create_user_info_widget() + layout.addWidget(self._user_info_widget) + + # 联系人列表占位符(将由ContactListWidget替换) + self._contact_list_placeholder = QLabel("联系人列表") + self._contact_list_placeholder.setAlignment(Qt.AlignmentFlag.AlignCenter) + self._contact_list_placeholder.setStyleSheet(""" + QLabel { + background-color: #f5f5f5; + color: #666; + padding: 20px; + } + """) + layout.addWidget(self._contact_list_placeholder, 1) + + return panel + + def _create_user_info_widget(self) -> QWidget: + """创建用户信息显示区域""" + widget = QFrame() + widget.setFrameShape(QFrame.Shape.StyledPanel) + widget.setStyleSheet(""" + QFrame { + background-color: #4a90d9; + border: none; + border-bottom: 1px solid #3a80c9; + } + """) + + layout = QHBoxLayout(widget) + layout.setContentsMargins(10, 10, 10, 10) + + # 用户头像占位 + self._avatar_label = QLabel("👤") + self._avatar_label.setFixedSize(40, 40) + self._avatar_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + self._avatar_label.setStyleSheet(""" + QLabel { + background-color: white; + border-radius: 20px; + font-size: 20px; + } + """) + layout.addWidget(self._avatar_label) + + # 用户名和状态 + info_layout = QVBoxLayout() + info_layout.setSpacing(2) + + self._username_label = QLabel("未登录") + self._username_label.setStyleSheet("color: white; font-weight: bold; font-size: 14px;") + info_layout.addWidget(self._username_label) + + self._status_label = QLabel("离线") + self._status_label.setStyleSheet("color: #ddd; font-size: 12px;") + info_layout.addWidget(self._status_label) + + layout.addLayout(info_layout, 1) + + return widget + + def _create_right_panel(self) -> QWidget: + """ + 创建右侧面板(聊天窗口和功能区) + + 实现聊天窗口面板和功能按钮区域 (需求 9.1) + """ + panel = QFrame() + panel.setFrameShape(QFrame.Shape.StyledPanel) + + layout = QVBoxLayout(panel) + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(0) + + # 聊天窗口占位符(将由ChatWidget替换) + self._chat_placeholder = QLabel("选择联系人开始聊天") + self._chat_placeholder.setAlignment(Qt.AlignmentFlag.AlignCenter) + self._chat_placeholder.setStyleSheet(""" + QLabel { + background-color: #fafafa; + color: #999; + font-size: 16px; + } + """) + layout.addWidget(self._chat_placeholder, 1) + + # 功能按钮区域 + self._function_bar = self._create_function_bar() + layout.addWidget(self._function_bar) + + return panel + + def _create_function_bar(self) -> QWidget: + """ + 创建功能按钮区域 + + 实现功能按钮区域 (需求 9.1) + """ + bar = QFrame() + bar.setFrameShape(QFrame.Shape.StyledPanel) + bar.setFixedHeight(50) + bar.setStyleSheet(""" + QFrame { + background-color: #f0f0f0; + border-top: 1px solid #ddd; + } + """) + + layout = QHBoxLayout(bar) + layout.setContentsMargins(10, 5, 10, 5) + layout.setSpacing(10) + + # 功能按钮将在后续子任务中添加 + layout.addStretch() + + return bar + + def _setup_menu_bar(self) -> None: + """设置菜单栏""" + menubar = self.menuBar() + + # 文件菜单 + file_menu = menubar.addMenu("文件(&F)") + + self._login_action = QAction("登录(&L)", self) + self._login_action.setShortcut("Ctrl+L") + self._login_action.triggered.connect(self._on_login) + file_menu.addAction(self._login_action) + + self._logout_action = QAction("注销(&O)", self) + self._logout_action.setEnabled(False) + self._logout_action.triggered.connect(self._on_logout) + file_menu.addAction(self._logout_action) + + file_menu.addSeparator() + + self._exit_action = QAction("退出(&X)", self) + self._exit_action.setShortcut("Ctrl+Q") + self._exit_action.triggered.connect(self.close) + file_menu.addAction(self._exit_action) + + # 聊天菜单 + chat_menu = menubar.addMenu("聊天(&C)") + + self._send_file_action = QAction("发送文件(&F)", self) + self._send_file_action.setShortcut("Ctrl+Shift+F") + self._send_file_action.setEnabled(False) + chat_menu.addAction(self._send_file_action) + + self._send_image_action = QAction("发送图片(&I)", self) + self._send_image_action.setShortcut("Ctrl+Shift+I") + self._send_image_action.setEnabled(False) + chat_menu.addAction(self._send_image_action) + + chat_menu.addSeparator() + + self._voice_call_action = QAction("语音通话(&V)", self) + self._voice_call_action.setShortcut("Ctrl+Shift+V") + self._voice_call_action.setEnabled(False) + chat_menu.addAction(self._voice_call_action) + + # 视图菜单 + view_menu = menubar.addMenu("视图(&V)") + + self._show_contacts_action = QAction("显示联系人(&C)", self) + self._show_contacts_action.setCheckable(True) + self._show_contacts_action.setChecked(True) + self._show_contacts_action.triggered.connect(self._toggle_contacts_panel) + view_menu.addAction(self._show_contacts_action) + + # 帮助菜单 + help_menu = menubar.addMenu("帮助(&H)") + + self._about_action = QAction("关于(&A)", self) + self._about_action.triggered.connect(self._show_about) + help_menu.addAction(self._about_action) + + def _setup_tool_bar(self) -> None: + """设置工具栏""" + toolbar = QToolBar("主工具栏") + toolbar.setMovable(False) + toolbar.setIconSize(QSize(24, 24)) + self.addToolBar(toolbar) + + # 工具栏按钮将在后续添加图标后完善 + toolbar.addAction(self._login_action) + toolbar.addSeparator() + toolbar.addAction(self._send_file_action) + toolbar.addAction(self._send_image_action) + toolbar.addAction(self._voice_call_action) + + def _setup_status_bar(self) -> None: + """设置状态栏""" + self._statusbar = QStatusBar() + self.setStatusBar(self._statusbar) + + # 连接状态标签 + self._connection_status_label = QLabel("未连接") + self._connection_status_label.setStyleSheet("color: #999;") + self._statusbar.addPermanentWidget(self._connection_status_label) + + self._statusbar.showMessage("就绪") + + def _connect_signals(self) -> None: + """连接信号和槽""" + self.message_received.connect(self._on_message_received) + self.user_status_changed.connect(self._on_user_status_changed) + self.connection_state_changed.connect(self._on_connection_state_changed) + + # ==================== 事件处理 ==================== + + def closeEvent(self, event: QCloseEvent) -> None: + """ + 窗口关闭事件 + + 实现后台运行 (需求 9.5) + WHEN 应用程序最小化 THEN P2P_Client SHALL 保持后台运行并继续接收消息 + """ + # 如果有系统托盘,最小化到托盘而不是关闭 + if self._system_tray and self._system_tray.is_available(): + event.ignore() + self.hide() + self._system_tray.show_message("P2P通信", "应用程序已最小化到系统托盘") + else: + # 确认退出 + reply = QMessageBox.question( + self, + "确认退出", + "确定要退出应用程序吗?", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + QMessageBox.StandardButton.No + ) + + if reply == QMessageBox.StandardButton.Yes: + self._cleanup() + event.accept() + else: + event.ignore() + + def resizeEvent(self, event) -> None: + """ + 窗口大小调整事件 + + 实现自适应调整界面布局 (需求 9.6) + WHEN 用户调整窗口大小 THEN P2P_Client SHALL 自适应调整界面布局 + """ + super().resizeEvent(event) + # 布局会自动调整,无需额外处理 + + # ==================== 槽函数 ==================== + + def _on_login(self) -> None: + """处理登录操作""" + from client.ui.login_dialog import LoginDialog + + dialog = LoginDialog(self) + if dialog.exec(): + user_info = dialog.get_user_info() + if user_info: + self.set_current_user(user_info) + self._statusbar.showMessage(f"已登录: {user_info.username}") + + def _on_logout(self) -> None: + """处理注销操作""" + reply = QMessageBox.question( + self, + "确认注销", + "确定要注销当前账户吗?", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + QMessageBox.StandardButton.No + ) + + if reply == QMessageBox.StandardButton.Yes: + self._current_user = None + self._update_user_display() + self._login_action.setEnabled(True) + self._logout_action.setEnabled(False) + self._statusbar.showMessage("已注销") + + def _toggle_contacts_panel(self, checked: bool) -> None: + """切换联系人面板显示""" + self._left_panel.setVisible(checked) + + def _show_about(self) -> None: + """显示关于对话框""" + QMessageBox.about( + self, + "关于 P2P 通信应用", + "P2P 网络通信应用程序\n\n" + "版本: 0.1.0\n\n" + "支持功能:\n" + "- 文本消息通信\n" + "- 文件传输\n" + "- 图片传输与显示\n" + "- 音视频播放\n" + "- 语音聊天" + ) + + def _on_message_received(self, message: Message) -> None: + """处理接收到的消息""" + logger.debug(f"Message received: {message.msg_type.value}") + # 消息处理将在ChatWidget中实现 + + def _on_user_status_changed(self, user_id: str, status: UserStatus) -> None: + """处理用户状态变化""" + logger.debug(f"User {user_id} status changed to {status.value}") + # 状态更新将在ContactListWidget中实现 + + def _on_connection_state_changed(self, state: str) -> None: + """处理连接状态变化""" + self._connection_status_label.setText(state) + + if state == "已连接": + self._connection_status_label.setStyleSheet("color: green;") + elif state == "连接中...": + self._connection_status_label.setStyleSheet("color: orange;") + else: + self._connection_status_label.setStyleSheet("color: red;") + + # ==================== 公共方法 ==================== + + def set_current_user(self, user_info: UserInfo) -> None: + """ + 设置当前用户 + + Args: + user_info: 用户信息 + """ + self._current_user = user_info + self._update_user_display() + self._login_action.setEnabled(False) + self._logout_action.setEnabled(True) + + logger.info(f"Current user set: {user_info.username}") + + def _update_user_display(self) -> None: + """更新用户显示""" + if self._current_user: + self._username_label.setText(self._current_user.display_name or self._current_user.username) + self._status_label.setText(self._current_user.status.value) + else: + self._username_label.setText("未登录") + self._status_label.setText("离线") + + def set_contact_list_widget(self, widget: QWidget) -> None: + """ + 设置联系人列表组件 + + Args: + widget: 联系人列表组件 + """ + # 移除占位符 + layout = self._left_panel.layout() + layout.removeWidget(self._contact_list_placeholder) + self._contact_list_placeholder.deleteLater() + + # 添加实际组件 + self._contact_list = widget + layout.addWidget(widget, 1) + + def set_chat_widget(self, widget: QWidget) -> None: + """ + 设置聊天组件 + + Args: + widget: 聊天组件 + """ + # 移除占位符 + layout = self._right_panel.layout() + layout.removeWidget(self._chat_placeholder) + self._chat_placeholder.deleteLater() + + # 添加实际组件(在功能栏之前) + self._chat_widget = widget + layout.insertWidget(0, widget, 1) + + def set_system_tray(self, tray_manager) -> None: + """ + 设置系统托盘管理器 + + Args: + tray_manager: 系统托盘管理器 + """ + self._system_tray = tray_manager + + def show_notification(self, title: str, message: str) -> None: + """ + 显示通知 + + 实现新消息通知 (需求 9.3) + WHEN 有新消息到达 THEN P2P_Client SHALL 在系统托盘显示通知提醒 + + Args: + title: 通知标题 + message: 通知内容 + """ + if self._system_tray: + self._system_tray.show_message(title, message) + + def enable_chat_actions(self, enabled: bool = True) -> None: + """ + 启用/禁用聊天相关操作 + + Args: + enabled: 是否启用 + """ + self._send_file_action.setEnabled(enabled) + self._send_image_action.setEnabled(enabled) + self._voice_call_action.setEnabled(enabled) + + def set_current_chat_peer(self, peer_id: Optional[str]) -> None: + """ + 设置当前聊天对象 + + Args: + peer_id: 对等端ID + """ + self._current_chat_peer = peer_id + self.enable_chat_actions(peer_id is not None) + + def _cleanup(self) -> None: + """清理资源""" + logger.info("MainWindow cleanup") + # 清理将在集成时实现 + + @property + def current_user(self) -> Optional[UserInfo]: + """获取当前用户""" + return self._current_user + + @property + def current_chat_peer(self) -> Optional[str]: + """获取当前聊天对象""" + return self._current_chat_peer diff --git a/client/ui/media_player_widget.py b/client/ui/media_player_widget.py new file mode 100644 index 0000000..11d728f --- /dev/null +++ b/client/ui/media_player_widget.py @@ -0,0 +1,335 @@ +# P2P Network Communication - Media Player Widget +""" +媒体播放器界面组件 +实现音频和视频播放器界面 + +需求: 6.3, 6.4, 6.5, 6.7 +""" + +import logging +from typing import Optional + +from PyQt6.QtWidgets import ( + QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, + QSlider, QFrame, QStackedWidget, QSizePolicy +) +from PyQt6.QtCore import Qt, pyqtSignal, QTimer + +from client.media_player import PlaybackState + +logger = logging.getLogger(__name__) + + +class MediaPlayerWidget(QWidget): + """ + 媒体播放器界面组件 + + 实现音频播放器界面 (需求 6.3) + 实现视频播放器界面 (需求 6.4) + 实现播放控制按钮 (需求 6.5) + 实现全屏模式 (需求 6.7) + """ + + # 信号定义 + play_requested = pyqtSignal() + pause_requested = pyqtSignal() + stop_requested = pyqtSignal() + seek_requested = pyqtSignal(float) # position in seconds + volume_changed = pyqtSignal(float) # volume 0.0-1.0 + fullscreen_requested = pyqtSignal(bool) + + def __init__(self, parent=None): + super().__init__(parent) + + self._is_video = False + self._duration = 0.0 + self._is_seeking = False + + self._setup_ui() + self._connect_signals() + + # 更新定时器 + self._update_timer = QTimer() + self._update_timer.setInterval(100) + self._update_timer.timeout.connect(self._on_update_timer) + + logger.info("MediaPlayerWidget initialized") + + def _setup_ui(self) -> None: + """设置UI""" + layout = QVBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(0) + + # 视频显示区域 + self._video_frame = QFrame() + self._video_frame.setStyleSheet("background-color: black;") + self._video_frame.setMinimumHeight(200) + self._video_frame.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) + + video_layout = QVBoxLayout(self._video_frame) + video_layout.setContentsMargins(0, 0, 0, 0) + + self._video_label = QLabel("无媒体") + self._video_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + self._video_label.setStyleSheet("color: #666; font-size: 16px;") + video_layout.addWidget(self._video_label) + + layout.addWidget(self._video_frame, 1) + + # 控制栏 + self._control_bar = self._create_control_bar() + layout.addWidget(self._control_bar) + + def _create_control_bar(self) -> QWidget: + """创建控制栏""" + bar = QFrame() + bar.setFixedHeight(80) + bar.setStyleSheet(""" + QFrame { + background-color: #2d2d2d; + } + """) + + layout = QVBoxLayout(bar) + layout.setContentsMargins(10, 5, 10, 5) + layout.setSpacing(5) + + # 进度条 + progress_layout = QHBoxLayout() + progress_layout.setSpacing(10) + + self._current_time_label = QLabel("00:00") + self._current_time_label.setStyleSheet("color: white; font-size: 12px;") + self._current_time_label.setFixedWidth(50) + progress_layout.addWidget(self._current_time_label) + + self._progress_slider = QSlider(Qt.Orientation.Horizontal) + self._progress_slider.setRange(0, 1000) + self._progress_slider.setValue(0) + self._progress_slider.setStyleSheet(""" + QSlider::groove:horizontal { + border: none; + height: 4px; + background: #555; + border-radius: 2px; + } + QSlider::handle:horizontal { + background: white; + width: 12px; + height: 12px; + margin: -4px 0; + border-radius: 6px; + } + QSlider::sub-page:horizontal { + background: #4a90d9; + border-radius: 2px; + } + """) + progress_layout.addWidget(self._progress_slider, 1) + + self._total_time_label = QLabel("00:00") + self._total_time_label.setStyleSheet("color: white; font-size: 12px;") + self._total_time_label.setFixedWidth(50) + progress_layout.addWidget(self._total_time_label) + + layout.addLayout(progress_layout) + + # 控制按钮 + button_layout = QHBoxLayout() + button_layout.setSpacing(10) + + button_style = """ + QPushButton { + background-color: transparent; + color: white; + border: none; + font-size: 18px; + padding: 5px 10px; + } + QPushButton:hover { + background-color: #444; + border-radius: 5px; + } + """ + + self._play_btn = QPushButton("▶") + self._play_btn.setStyleSheet(button_style) + self._play_btn.setFixedSize(40, 40) + self._play_btn.clicked.connect(self._on_play_clicked) + button_layout.addWidget(self._play_btn) + + self._stop_btn = QPushButton("⏹") + self._stop_btn.setStyleSheet(button_style) + self._stop_btn.setFixedSize(40, 40) + self._stop_btn.clicked.connect(self._on_stop_clicked) + button_layout.addWidget(self._stop_btn) + + button_layout.addStretch() + + # 音量控制 + self._volume_btn = QPushButton("🔊") + self._volume_btn.setStyleSheet(button_style) + self._volume_btn.setFixedSize(40, 40) + button_layout.addWidget(self._volume_btn) + + self._volume_slider = QSlider(Qt.Orientation.Horizontal) + self._volume_slider.setRange(0, 100) + self._volume_slider.setValue(100) + self._volume_slider.setFixedWidth(100) + self._volume_slider.setStyleSheet(""" + QSlider::groove:horizontal { + border: none; + height: 4px; + background: #555; + border-radius: 2px; + } + QSlider::handle:horizontal { + background: white; + width: 10px; + height: 10px; + margin: -3px 0; + border-radius: 5px; + } + QSlider::sub-page:horizontal { + background: #4a90d9; + border-radius: 2px; + } + """) + button_layout.addWidget(self._volume_slider) + + # 全屏按钮 + self._fullscreen_btn = QPushButton("⛶") + self._fullscreen_btn.setStyleSheet(button_style) + self._fullscreen_btn.setFixedSize(40, 40) + self._fullscreen_btn.clicked.connect(self._on_fullscreen_clicked) + button_layout.addWidget(self._fullscreen_btn) + + layout.addLayout(button_layout) + + return bar + + def _connect_signals(self) -> None: + """连接信号""" + self._progress_slider.sliderPressed.connect(self._on_slider_pressed) + self._progress_slider.sliderReleased.connect(self._on_slider_released) + self._volume_slider.valueChanged.connect(self._on_volume_changed) + + def _format_time(self, seconds: float) -> str: + """格式化时间""" + minutes = int(seconds // 60) + secs = int(seconds % 60) + return f"{minutes:02d}:{secs:02d}" + + def _on_play_clicked(self) -> None: + """处理播放/暂停按钮点击""" + if self._play_btn.text() == "▶": + self.play_requested.emit() + else: + self.pause_requested.emit() + + def _on_stop_clicked(self) -> None: + """处理停止按钮点击""" + self.stop_requested.emit() + + def _on_fullscreen_clicked(self) -> None: + """处理全屏按钮点击""" + self.fullscreen_requested.emit(True) + + def _on_slider_pressed(self) -> None: + """进度条按下""" + self._is_seeking = True + + def _on_slider_released(self) -> None: + """进度条释放""" + self._is_seeking = False + if self._duration > 0: + position = (self._progress_slider.value() / 1000) * self._duration + self.seek_requested.emit(position) + + def _on_volume_changed(self, value: int) -> None: + """音量变化""" + volume = value / 100.0 + self.volume_changed.emit(volume) + + if value == 0: + self._volume_btn.setText("🔇") + elif value < 50: + self._volume_btn.setText("🔉") + else: + self._volume_btn.setText("🔊") + + def _on_update_timer(self) -> None: + """更新定时器回调""" + # 由外部调用update_position更新 + pass + + # ==================== 公共方法 ==================== + + def set_media_info(self, file_name: str, duration: float, is_video: bool = False) -> None: + """ + 设置媒体信息 + + Args: + file_name: 文件名 + duration: 时长(秒) + is_video: 是否是视频 + """ + self._duration = duration + self._is_video = is_video + + self._video_label.setText(file_name) + self._total_time_label.setText(self._format_time(duration)) + self._progress_slider.setValue(0) + self._current_time_label.setText("00:00") + + def update_position(self, position: float) -> None: + """ + 更新播放位置 + + Args: + position: 当前位置(秒) + """ + if not self._is_seeking and self._duration > 0: + slider_value = int((position / self._duration) * 1000) + self._progress_slider.setValue(slider_value) + + self._current_time_label.setText(self._format_time(position)) + + def update_state(self, state: PlaybackState) -> None: + """ + 更新播放状态 + + Args: + state: 播放状态 + """ + if state == PlaybackState.PLAYING: + self._play_btn.setText("⏸") + self._update_timer.start() + elif state == PlaybackState.PAUSED: + self._play_btn.setText("▶") + self._update_timer.stop() + elif state == PlaybackState.STOPPED: + self._play_btn.setText("▶") + self._update_timer.stop() + self._progress_slider.setValue(0) + self._current_time_label.setText("00:00") + + def set_volume(self, volume: float) -> None: + """ + 设置音量 + + Args: + volume: 音量 (0.0-1.0) + """ + self._volume_slider.setValue(int(volume * 100)) + + def clear(self) -> None: + """清空播放器""" + self._duration = 0.0 + self._video_label.setText("无媒体") + self._total_time_label.setText("00:00") + self._current_time_label.setText("00:00") + self._progress_slider.setValue(0) + self._play_btn.setText("▶") + self._update_timer.stop() diff --git a/client/ui/system_tray.py b/client/ui/system_tray.py new file mode 100644 index 0000000..7dfdf86 --- /dev/null +++ b/client/ui/system_tray.py @@ -0,0 +1,179 @@ +# P2P Network Communication - System Tray Manager +""" +系统托盘管理器 +实现系统托盘图标、新消息通知和后台运行 + +需求: 9.3, 9.5 +""" + +import logging +from typing import Optional, Callable + +from PyQt6.QtWidgets import QSystemTrayIcon, QMenu, QApplication +from PyQt6.QtGui import QIcon, QAction +from PyQt6.QtCore import pyqtSignal, QObject + +logger = logging.getLogger(__name__) + + +class SystemTrayManager(QObject): + """ + 系统托盘管理器 + + 实现系统托盘图标 (需求 9.3) + 实现新消息通知 (需求 9.3) + 实现后台运行 (需求 9.5) + """ + + # 信号定义 + show_window_requested = pyqtSignal() + quit_requested = pyqtSignal() + + def __init__(self, parent=None): + super().__init__(parent) + + self._tray_icon: Optional[QSystemTrayIcon] = None + self._menu: Optional[QMenu] = None + self._unread_count: int = 0 + + self._setup_tray() + + logger.info("SystemTrayManager initialized") + + def _setup_tray(self) -> None: + """设置系统托盘""" + if not QSystemTrayIcon.isSystemTrayAvailable(): + logger.warning("System tray is not available") + return + + self._tray_icon = QSystemTrayIcon(self.parent()) + + # 设置默认图标(使用应用程序图标或创建简单图标) + app = QApplication.instance() + if app and not app.windowIcon().isNull(): + self._tray_icon.setIcon(app.windowIcon()) + else: + # 创建一个简单的默认图标 + from PyQt6.QtGui import QPixmap, QPainter, QColor + pixmap = QPixmap(32, 32) + pixmap.fill(QColor("#4a90d9")) + painter = QPainter(pixmap) + painter.setPen(QColor("white")) + painter.drawText(pixmap.rect(), 0x0084, "P2P") # AlignCenter + painter.end() + self._tray_icon.setIcon(QIcon(pixmap)) + + self._tray_icon.setToolTip("P2P 通信应用") + + # 创建菜单 + self._menu = QMenu() + + show_action = QAction("显示主窗口", self._menu) + show_action.triggered.connect(self._on_show_window) + self._menu.addAction(show_action) + + self._menu.addSeparator() + + self._status_action = QAction("状态: 在线", self._menu) + self._status_action.setEnabled(False) + self._menu.addAction(self._status_action) + + self._menu.addSeparator() + + quit_action = QAction("退出", self._menu) + quit_action.triggered.connect(self._on_quit) + self._menu.addAction(quit_action) + + self._tray_icon.setContextMenu(self._menu) + + # 连接信号 + self._tray_icon.activated.connect(self._on_tray_activated) + + def _on_show_window(self) -> None: + """显示主窗口""" + self.show_window_requested.emit() + + def _on_quit(self) -> None: + """退出应用""" + self.quit_requested.emit() + + def _on_tray_activated(self, reason: QSystemTrayIcon.ActivationReason) -> None: + """托盘图标激活""" + if reason == QSystemTrayIcon.ActivationReason.DoubleClick: + self.show_window_requested.emit() + + # ==================== 公共方法 ==================== + + def is_available(self) -> bool: + """ + 检查系统托盘是否可用 + + Returns: + 是否可用 + """ + return self._tray_icon is not None + + def show(self) -> None: + """显示托盘图标""" + if self._tray_icon: + self._tray_icon.show() + + def hide(self) -> None: + """隐藏托盘图标""" + if self._tray_icon: + self._tray_icon.hide() + + def show_message(self, title: str, message: str, + icon: QSystemTrayIcon.MessageIcon = QSystemTrayIcon.MessageIcon.Information, + duration: int = 3000) -> None: + """ + 显示通知消息 + + 实现新消息通知 (需求 9.3) + WHEN 有新消息到达 THEN P2P_Client SHALL 在系统托盘显示通知提醒 + + Args: + title: 通知标题 + message: 通知内容 + icon: 图标类型 + duration: 显示时长(毫秒) + """ + if self._tray_icon: + self._tray_icon.showMessage(title, message, icon, duration) + logger.debug(f"Tray notification: {title} - {message}") + + def set_status(self, status: str) -> None: + """ + 设置状态显示 + + Args: + status: 状态文本 + """ + if self._status_action: + self._status_action.setText(f"状态: {status}") + + if self._tray_icon: + self._tray_icon.setToolTip(f"P2P 通信应用 - {status}") + + def set_unread_count(self, count: int) -> None: + """ + 设置未读消息数 + + Args: + count: 未读消息数 + """ + self._unread_count = count + + if self._tray_icon: + if count > 0: + self._tray_icon.setToolTip(f"P2P 通信应用 - {count} 条未读消息") + else: + self._tray_icon.setToolTip("P2P 通信应用") + + def increment_unread(self) -> None: + """增加未读消息数""" + self.set_unread_count(self._unread_count + 1) + + def clear_unread(self) -> None: + """清除未读消息数""" + self.set_unread_count(0) diff --git a/client/ui/voice_call_widget.py b/client/ui/voice_call_widget.py new file mode 100644 index 0000000..bc96d8d --- /dev/null +++ b/client/ui/voice_call_widget.py @@ -0,0 +1,393 @@ +# P2P Network Communication - Voice Call Widget +""" +语音通话界面组件 +实现来电提示、通话中界面和通话状态显示 + +需求: 7.2, 7.8 +""" + +import logging +from typing import Optional + +from PyQt6.QtWidgets import ( + QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, + QFrame, QDialog +) +from PyQt6.QtCore import Qt, pyqtSignal, QTimer + +from client.voice_chat import CallState, NetworkQuality + +logger = logging.getLogger(__name__) + + +class IncomingCallDialog(QDialog): + """ + 来电提示对话框 + + 实现来电提示界面 (需求 7.2) + WHEN 收到语音通话邀请 THEN P2P_Client SHALL 显示来电提示并提供接听或拒绝选项 + """ + + accepted = pyqtSignal() + rejected = pyqtSignal() + + def __init__(self, caller_name: str, parent=None): + super().__init__(parent) + self.caller_name = caller_name + self._setup_ui() + + # 响铃动画定时器 + self._ring_timer = QTimer() + self._ring_timer.setInterval(500) + self._ring_timer.timeout.connect(self._animate_ring) + self._ring_state = False + self._ring_timer.start() + + def _setup_ui(self) -> None: + """设置UI""" + self.setWindowTitle("来电") + self.setFixedSize(300, 200) + self.setModal(True) + self.setWindowFlags(self.windowFlags() | Qt.WindowType.WindowStaysOnTopHint) + + layout = QVBoxLayout(self) + layout.setSpacing(20) + layout.setContentsMargins(30, 30, 30, 30) + + # 来电图标 + self._icon_label = QLabel("📞") + self._icon_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + self._icon_label.setStyleSheet("font-size: 48px;") + layout.addWidget(self._icon_label) + + # 来电信息 + info_label = QLabel(f"{self.caller_name}\n正在呼叫您...") + info_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + info_label.setStyleSheet("font-size: 16px;") + layout.addWidget(info_label) + + # 按钮 + button_layout = QHBoxLayout() + button_layout.setSpacing(20) + + self._reject_btn = QPushButton("拒绝") + self._reject_btn.setFixedSize(80, 40) + self._reject_btn.setStyleSheet(""" + QPushButton { + background-color: #f44336; + color: white; + border: none; + border-radius: 20px; + font-size: 14px; + } + QPushButton:hover { + background-color: #d32f2f; + } + """) + self._reject_btn.clicked.connect(self._on_reject) + button_layout.addWidget(self._reject_btn) + + self._accept_btn = QPushButton("接听") + self._accept_btn.setFixedSize(80, 40) + self._accept_btn.setStyleSheet(""" + QPushButton { + background-color: #4caf50; + color: white; + border: none; + border-radius: 20px; + font-size: 14px; + } + QPushButton:hover { + background-color: #388e3c; + } + """) + self._accept_btn.clicked.connect(self._on_accept) + button_layout.addWidget(self._accept_btn) + + layout.addLayout(button_layout) + + def _animate_ring(self) -> None: + """响铃动画""" + self._ring_state = not self._ring_state + self._icon_label.setText("📞" if self._ring_state else "📱") + + def _on_accept(self) -> None: + """接听""" + self._ring_timer.stop() + self.accepted.emit() + self.accept() + + def _on_reject(self) -> None: + """拒绝""" + self._ring_timer.stop() + self.rejected.emit() + self.reject() + + def closeEvent(self, event) -> None: + """关闭事件""" + self._ring_timer.stop() + self.rejected.emit() + super().closeEvent(event) + + +class VoiceCallWidget(QWidget): + """ + 语音通话界面组件 + + 实现通话中界面 (需求 7.8) + 实现通话状态显示 (需求 7.8) + WHEN 语音通话进行中 THEN P2P_Client SHALL 显示通话时长和网络质量指示 + """ + + # 信号定义 + mute_toggled = pyqtSignal(bool) + end_call_requested = pyqtSignal() + + def __init__(self, parent=None): + super().__init__(parent) + + self._peer_name: str = "" + self._is_muted: bool = False + self._call_duration: int = 0 + + self._setup_ui() + + # 通话计时器 + self._duration_timer = QTimer() + self._duration_timer.setInterval(1000) + self._duration_timer.timeout.connect(self._update_duration) + + logger.info("VoiceCallWidget initialized") + + def _setup_ui(self) -> None: + """设置UI""" + self.setStyleSheet("background-color: #2d2d2d;") + + layout = QVBoxLayout(self) + layout.setSpacing(20) + layout.setContentsMargins(30, 30, 30, 30) + + layout.addStretch() + + # 通话对象头像 + avatar_label = QLabel("👤") + avatar_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + avatar_label.setStyleSheet(""" + font-size: 64px; + background-color: #4a90d9; + border-radius: 50px; + padding: 20px; + """) + avatar_label.setFixedSize(100, 100) + + avatar_container = QHBoxLayout() + avatar_container.addStretch() + avatar_container.addWidget(avatar_label) + avatar_container.addStretch() + layout.addLayout(avatar_container) + + # 通话对象名称 + self._peer_name_label = QLabel("未知用户") + self._peer_name_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + self._peer_name_label.setStyleSheet("color: white; font-size: 24px; font-weight: bold;") + layout.addWidget(self._peer_name_label) + + # 通话状态 + self._status_label = QLabel("正在连接...") + self._status_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + self._status_label.setStyleSheet("color: #aaa; font-size: 14px;") + layout.addWidget(self._status_label) + + # 通话时长 + self._duration_label = QLabel("00:00") + self._duration_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + self._duration_label.setStyleSheet("color: white; font-size: 32px;") + layout.addWidget(self._duration_label) + + # 网络质量 + self._quality_label = QLabel("网络质量: --") + self._quality_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + self._quality_label.setStyleSheet("color: #aaa; font-size: 12px;") + layout.addWidget(self._quality_label) + + layout.addStretch() + + # 控制按钮 + button_layout = QHBoxLayout() + button_layout.setSpacing(30) + + button_layout.addStretch() + + # 静音按钮 + self._mute_btn = QPushButton("🎤") + self._mute_btn.setFixedSize(60, 60) + self._mute_btn.setStyleSheet(""" + QPushButton { + background-color: #555; + color: white; + border: none; + border-radius: 30px; + font-size: 24px; + } + QPushButton:hover { + background-color: #666; + } + QPushButton:checked { + background-color: #f44336; + } + """) + self._mute_btn.setCheckable(True) + self._mute_btn.clicked.connect(self._on_mute_clicked) + button_layout.addWidget(self._mute_btn) + + # 挂断按钮 + self._end_btn = QPushButton("📞") + self._end_btn.setFixedSize(70, 70) + self._end_btn.setStyleSheet(""" + QPushButton { + background-color: #f44336; + color: white; + border: none; + border-radius: 35px; + font-size: 28px; + } + QPushButton:hover { + background-color: #d32f2f; + } + """) + self._end_btn.clicked.connect(self._on_end_clicked) + button_layout.addWidget(self._end_btn) + + # 扬声器按钮 + self._speaker_btn = QPushButton("🔊") + self._speaker_btn.setFixedSize(60, 60) + self._speaker_btn.setStyleSheet(""" + QPushButton { + background-color: #555; + color: white; + border: none; + border-radius: 30px; + font-size: 24px; + } + QPushButton:hover { + background-color: #666; + } + """) + button_layout.addWidget(self._speaker_btn) + + button_layout.addStretch() + + layout.addLayout(button_layout) + layout.addSpacing(30) + + def _format_duration(self, seconds: int) -> str: + """格式化时长""" + minutes = seconds // 60 + secs = seconds % 60 + return f"{minutes:02d}:{secs:02d}" + + def _update_duration(self) -> None: + """更新通话时长""" + self._call_duration += 1 + self._duration_label.setText(self._format_duration(self._call_duration)) + + def _on_mute_clicked(self) -> None: + """处理静音按钮点击""" + self._is_muted = self._mute_btn.isChecked() + self._mute_btn.setText("🔇" if self._is_muted else "🎤") + self.mute_toggled.emit(self._is_muted) + + def _on_end_clicked(self) -> None: + """处理挂断按钮点击""" + self.end_call_requested.emit() + + # ==================== 公共方法 ==================== + + def start_call(self, peer_name: str) -> None: + """ + 开始通话 + + Args: + peer_name: 通话对象名称 + """ + self._peer_name = peer_name + self._peer_name_label.setText(peer_name) + self._status_label.setText("通话中") + self._call_duration = 0 + self._duration_label.setText("00:00") + self._duration_timer.start() + + logger.info(f"Call started with {peer_name}") + + def end_call(self) -> None: + """结束通话""" + self._duration_timer.stop() + self._status_label.setText("通话已结束") + + logger.info("Call ended") + + def update_state(self, state: CallState) -> None: + """ + 更新通话状态 + + Args: + state: 通话状态 + """ + state_texts = { + CallState.IDLE: "空闲", + CallState.CALLING: "正在呼叫...", + CallState.RINGING: "对方响铃中...", + CallState.CONNECTED: "通话中", + CallState.ENDING: "正在结束...", + } + self._status_label.setText(state_texts.get(state, "未知")) + + if state == CallState.CONNECTED: + self._duration_timer.start() + elif state in (CallState.IDLE, CallState.ENDING): + self._duration_timer.stop() + + def update_network_quality(self, quality: NetworkQuality) -> None: + """ + 更新网络质量显示 + + 实现网络质量指示 (需求 7.8) + WHEN 语音通话进行中 THEN P2P_Client SHALL 显示通话时长和网络质量指示 + + Args: + quality: 网络质量 + """ + quality_texts = { + NetworkQuality.EXCELLENT: ("优秀", "#4caf50"), + NetworkQuality.GOOD: ("良好", "#8bc34a"), + NetworkQuality.FAIR: ("一般", "#ffeb3b"), + NetworkQuality.POOR: ("较差", "#ff9800"), + NetworkQuality.BAD: ("很差", "#f44336"), + } + + text, color = quality_texts.get(quality, ("未知", "#999")) + self._quality_label.setText(f"网络质量: {text}") + self._quality_label.setStyleSheet(f"color: {color}; font-size: 12px;") + + def get_call_duration(self) -> int: + """ + 获取通话时长 + + Returns: + 通话时长(秒) + """ + return self._call_duration + + def reset(self) -> None: + """重置界面""" + self._duration_timer.stop() + self._peer_name = "" + self._is_muted = False + self._call_duration = 0 + + self._peer_name_label.setText("未知用户") + self._status_label.setText("正在连接...") + self._duration_label.setText("00:00") + self._quality_label.setText("网络质量: --") + self._mute_btn.setChecked(False) + self._mute_btn.setText("🎤") diff --git a/shared/__init__.py b/shared/__init__.py index e59927b..6029aa1 100644 --- a/shared/__init__.py +++ b/shared/__init__.py @@ -27,6 +27,26 @@ from shared.message_handler import ( MessageRoutingError, ) +from shared.security import ( + SecurityError, + EncryptionError, + DecryptionError, + KeyManagementError, + CertificateError, + EncryptedData, + AESCipher, + TLSManager, + MessageEncryptor, + FileEncryptor, + KeyManager, + LocalDataEncryptor, + create_message_encryptor, + create_file_encryptor, + create_local_data_encryptor, + encrypt_message, + decrypt_message, +) + __all__ = [ "MessageType", "UserStatus", @@ -44,4 +64,22 @@ __all__ = [ "MessageValidationError", "MessageSerializationError", "MessageRoutingError", + # Security + "SecurityError", + "EncryptionError", + "DecryptionError", + "KeyManagementError", + "CertificateError", + "EncryptedData", + "AESCipher", + "TLSManager", + "MessageEncryptor", + "FileEncryptor", + "KeyManager", + "LocalDataEncryptor", + "create_message_encryptor", + "create_file_encryptor", + "create_local_data_encryptor", + "encrypt_message", + "decrypt_message", ] diff --git a/shared/security.py b/shared/security.py new file mode 100644 index 0000000..10c0d9e --- /dev/null +++ b/shared/security.py @@ -0,0 +1,1037 @@ +# P2P Network Communication - Security Module +""" +安全模块 +负责传输加密、本地数据加密和密钥管理 + +需求: 10.1, 10.2, 10.3 +""" + +import base64 +import hashlib +import hmac +import json +import logging +import os +import secrets +import ssl +import struct +from dataclasses import dataclass, field +from datetime import datetime +from pathlib import Path +from typing import Optional, Tuple, Union + +from cryptography.hazmat.primitives import hashes, padding, serialization +from cryptography.hazmat.primitives.asymmetric import rsa, padding as asym_padding +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes +from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC +from cryptography.hazmat.backends import default_backend +from cryptography.x509 import load_pem_x509_certificate + + +# 设置日志 +logger = logging.getLogger(__name__) + + +class SecurityError(Exception): + """安全错误基类""" + pass + + +class EncryptionError(SecurityError): + """加密错误""" + pass + + +class DecryptionError(SecurityError): + """解密错误""" + pass + + +class KeyManagementError(SecurityError): + """密钥管理错误""" + pass + + +class CertificateError(SecurityError): + """证书错误""" + pass + + +@dataclass +class EncryptedData: + """加密数据结构""" + ciphertext: bytes + iv: bytes # 初始化向量 + tag: bytes # 认证标签 (GCM模式) + salt: bytes = field(default_factory=bytes) # 用于密钥派生 + + def to_bytes(self) -> bytes: + """序列化为字节流""" + # 格式: [iv_len(2)][iv][tag_len(2)][tag][salt_len(2)][salt][ciphertext] + result = struct.pack('!H', len(self.iv)) + self.iv + result += struct.pack('!H', len(self.tag)) + self.tag + result += struct.pack('!H', len(self.salt)) + self.salt + result += self.ciphertext + return result + + @classmethod + def from_bytes(cls, data: bytes) -> "EncryptedData": + """从字节流反序列化""" + offset = 0 + + # 读取IV + iv_len = struct.unpack('!H', data[offset:offset+2])[0] + offset += 2 + iv = data[offset:offset+iv_len] + offset += iv_len + + # 读取tag + tag_len = struct.unpack('!H', data[offset:offset+2])[0] + offset += 2 + tag = data[offset:offset+tag_len] + offset += tag_len + + # 读取salt + salt_len = struct.unpack('!H', data[offset:offset+2])[0] + offset += 2 + salt = data[offset:offset+salt_len] + offset += salt_len + + # 剩余为密文 + ciphertext = data[offset:] + + return cls(ciphertext=ciphertext, iv=iv, tag=tag, salt=salt) + + def to_base64(self) -> str: + """转换为Base64字符串""" + return base64.b64encode(self.to_bytes()).decode('utf-8') + + @classmethod + def from_base64(cls, data: str) -> "EncryptedData": + """从Base64字符串创建""" + return cls.from_bytes(base64.b64decode(data)) + + + +class AESCipher: + """ + AES-256-GCM 加密器 + + 用于消息和文件的加密传输 (需求 10.1, 10.3) + 以及本地数据的加密存储 (需求 10.2) + """ + + # AES-256 密钥长度 + KEY_SIZE = 32 # 256 bits + + # GCM IV 长度 + IV_SIZE = 12 # 96 bits (推荐) + + # GCM 认证标签长度 + TAG_SIZE = 16 # 128 bits + + # PBKDF2 盐长度 + SALT_SIZE = 16 # 128 bits + + # PBKDF2 迭代次数 + PBKDF2_ITERATIONS = 100000 + + def __init__(self, key: Optional[bytes] = None): + """ + 初始化AES加密器 + + Args: + key: 256位密钥,如果为None则生成新密钥 + """ + if key is None: + self._key = self.generate_key() + else: + if len(key) != self.KEY_SIZE: + raise EncryptionError(f"Key must be {self.KEY_SIZE} bytes") + self._key = key + + @property + def key(self) -> bytes: + """获取密钥""" + return self._key + + @staticmethod + def generate_key() -> bytes: + """ + 生成随机256位密钥 + + Returns: + 32字节随机密钥 + """ + return secrets.token_bytes(AESCipher.KEY_SIZE) + + @staticmethod + def generate_iv() -> bytes: + """ + 生成随机IV + + Returns: + 12字节随机IV + """ + return secrets.token_bytes(AESCipher.IV_SIZE) + + @staticmethod + def derive_key_from_password(password: str, salt: Optional[bytes] = None) -> Tuple[bytes, bytes]: + """ + 从密码派生密钥 (PBKDF2) + + 实现密钥管理 (需求 10.2) + + Args: + password: 用户密码 + salt: 盐值,如果为None则生成新盐 + + Returns: + (密钥, 盐值) 元组 + """ + if salt is None: + salt = secrets.token_bytes(AESCipher.SALT_SIZE) + + kdf = PBKDF2HMAC( + algorithm=hashes.SHA256(), + length=AESCipher.KEY_SIZE, + salt=salt, + iterations=AESCipher.PBKDF2_ITERATIONS, + backend=default_backend() + ) + + key = kdf.derive(password.encode('utf-8')) + return key, salt + + def encrypt(self, plaintext: bytes) -> EncryptedData: + """ + 加密数据 (AES-256-GCM) + + 实现消息加密传输 (需求 10.1) + 实现文件加密传输 (需求 10.3) + + Args: + plaintext: 明文数据 + + Returns: + 加密数据对象 + + Raises: + EncryptionError: 加密失败 + """ + try: + iv = self.generate_iv() + + cipher = Cipher( + algorithms.AES(self._key), + modes.GCM(iv), + backend=default_backend() + ) + encryptor = cipher.encryptor() + + ciphertext = encryptor.update(plaintext) + encryptor.finalize() + + return EncryptedData( + ciphertext=ciphertext, + iv=iv, + tag=encryptor.tag + ) + + except Exception as e: + raise EncryptionError(f"Encryption failed: {e}") + + def decrypt(self, encrypted_data: EncryptedData) -> bytes: + """ + 解密数据 (AES-256-GCM) + + Args: + encrypted_data: 加密数据对象 + + Returns: + 解密后的明文 + + Raises: + DecryptionError: 解密失败 + """ + try: + cipher = Cipher( + algorithms.AES(self._key), + modes.GCM(encrypted_data.iv, encrypted_data.tag), + backend=default_backend() + ) + decryptor = cipher.decryptor() + + plaintext = decryptor.update(encrypted_data.ciphertext) + decryptor.finalize() + + return plaintext + + except Exception as e: + raise DecryptionError(f"Decryption failed: {e}") + + def encrypt_with_password(self, plaintext: bytes, password: str) -> EncryptedData: + """ + 使用密码加密数据 + + 实现 AES-256 加密存储 (需求 10.2) + + Args: + plaintext: 明文数据 + password: 用户密码 + + Returns: + 加密数据对象(包含盐值) + """ + key, salt = self.derive_key_from_password(password) + + # 临时使用派生密钥 + original_key = self._key + self._key = key + + try: + encrypted = self.encrypt(plaintext) + encrypted.salt = salt + return encrypted + finally: + self._key = original_key + + def decrypt_with_password(self, encrypted_data: EncryptedData, password: str) -> bytes: + """ + 使用密码解密数据 + + Args: + encrypted_data: 加密数据对象 + password: 用户密码 + + Returns: + 解密后的明文 + """ + key, _ = self.derive_key_from_password(password, encrypted_data.salt) + + # 临时使用派生密钥 + original_key = self._key + self._key = key + + try: + return self.decrypt(encrypted_data) + finally: + self._key = original_key + + + +class TLSManager: + """ + TLS/SSL 连接管理器 + + 实现 TLS/SSL 连接 (需求 10.1) + """ + + def __init__(self, cert_file: Optional[str] = None, + key_file: Optional[str] = None, + ca_file: Optional[str] = None): + """ + 初始化TLS管理器 + + Args: + cert_file: 证书文件路径 + key_file: 私钥文件路径 + ca_file: CA证书文件路径 + """ + self.cert_file = cert_file + self.key_file = key_file + self.ca_file = ca_file + + def create_server_ssl_context(self) -> ssl.SSLContext: + """ + 创建服务器SSL上下文 + + 实现 TLS/SSL 连接 (需求 10.1) + + Returns: + SSL上下文对象 + + Raises: + CertificateError: 证书配置错误 + """ + try: + context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + context.minimum_version = ssl.TLSVersion.TLSv1_2 + + # 设置安全选项 + context.options |= ssl.OP_NO_SSLv2 + context.options |= ssl.OP_NO_SSLv3 + context.options |= ssl.OP_NO_TLSv1 + context.options |= ssl.OP_NO_TLSv1_1 + + # 加载证书和私钥 + if self.cert_file and self.key_file: + context.load_cert_chain(self.cert_file, self.key_file) + else: + # 生成自签名证书(仅用于开发/测试) + logger.warning("No certificate provided, using self-signed certificate") + self._generate_self_signed_cert() + if self.cert_file and self.key_file: + context.load_cert_chain(self.cert_file, self.key_file) + + # 设置密码套件 + context.set_ciphers('ECDHE+AESGCM:DHE+AESGCM:ECDHE+CHACHA20:DHE+CHACHA20') + + return context + + except ssl.SSLError as e: + raise CertificateError(f"Failed to create server SSL context: {e}") + except FileNotFoundError as e: + raise CertificateError(f"Certificate file not found: {e}") + + def create_client_ssl_context(self, verify: bool = True) -> ssl.SSLContext: + """ + 创建客户端SSL上下文 + + 实现 TLS/SSL 连接 (需求 10.1) + + Args: + verify: 是否验证服务器证书 + + Returns: + SSL上下文对象 + """ + try: + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + context.minimum_version = ssl.TLSVersion.TLSv1_2 + + # 设置安全选项 + context.options |= ssl.OP_NO_SSLv2 + context.options |= ssl.OP_NO_SSLv3 + context.options |= ssl.OP_NO_TLSv1 + context.options |= ssl.OP_NO_TLSv1_1 + + if verify: + context.verify_mode = ssl.CERT_REQUIRED + context.check_hostname = True + + # 加载CA证书 + if self.ca_file: + context.load_verify_locations(self.ca_file) + else: + # 使用系统默认CA + context.load_default_certs() + else: + # 不验证证书(仅用于开发/测试) + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + logger.warning("SSL certificate verification disabled") + + # 设置密码套件 + context.set_ciphers('ECDHE+AESGCM:DHE+AESGCM:ECDHE+CHACHA20:DHE+CHACHA20') + + return context + + except ssl.SSLError as e: + raise CertificateError(f"Failed to create client SSL context: {e}") + + def _generate_self_signed_cert(self) -> None: + """ + 生成自签名证书(仅用于开发/测试) + """ + from cryptography import x509 + from cryptography.x509.oid import NameOID + from datetime import timedelta + + # 生成私钥 + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + backend=default_backend() + ) + + # 创建证书 + subject = issuer = x509.Name([ + x509.NameAttribute(NameOID.COUNTRY_NAME, "CN"), + x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "Beijing"), + x509.NameAttribute(NameOID.LOCALITY_NAME, "Beijing"), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, "P2P Chat"), + x509.NameAttribute(NameOID.COMMON_NAME, "localhost"), + ]) + + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(private_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.utcnow()) + .not_valid_after(datetime.utcnow() + timedelta(days=365)) + .add_extension( + x509.SubjectAlternativeName([ + x509.DNSName("localhost"), + x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")), + ]), + critical=False, + ) + .sign(private_key, hashes.SHA256(), default_backend()) + ) + + # 保存证书和私钥 + cert_dir = Path("certs") + cert_dir.mkdir(exist_ok=True) + + self.cert_file = str(cert_dir / "server.crt") + self.key_file = str(cert_dir / "server.key") + + with open(self.cert_file, "wb") as f: + f.write(cert.public_bytes(serialization.Encoding.PEM)) + + with open(self.key_file, "wb") as f: + f.write(private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption() + )) + + logger.info(f"Generated self-signed certificate: {self.cert_file}") + + + +class MessageEncryptor: + """ + 消息加密器 + + 实现消息加密传输 (需求 10.1) + """ + + def __init__(self, cipher: Optional[AESCipher] = None): + """ + 初始化消息加密器 + + Args: + cipher: AES加密器,如果为None则创建新的 + """ + self._cipher = cipher or AESCipher() + + @property + def key(self) -> bytes: + """获取加密密钥""" + return self._cipher.key + + def set_key(self, key: bytes) -> None: + """ + 设置加密密钥 + + Args: + key: 256位密钥 + """ + self._cipher = AESCipher(key) + + def encrypt_message(self, message_data: bytes) -> bytes: + """ + 加密消息数据 + + 实现消息加密传输 (需求 10.1) + + Args: + message_data: 消息数据(序列化后的字节流) + + Returns: + 加密后的字节流 + """ + encrypted = self._cipher.encrypt(message_data) + return encrypted.to_bytes() + + def decrypt_message(self, encrypted_data: bytes) -> bytes: + """ + 解密消息数据 + + Args: + encrypted_data: 加密的字节流 + + Returns: + 解密后的消息数据 + """ + encrypted = EncryptedData.from_bytes(encrypted_data) + return self._cipher.decrypt(encrypted) + + +class FileEncryptor: + """ + 文件加密器 + + 实现文件加密传输 (需求 10.3) + """ + + # 文件加密块大小 + BLOCK_SIZE = 64 * 1024 # 64KB + + def __init__(self, cipher: Optional[AESCipher] = None): + """ + 初始化文件加密器 + + Args: + cipher: AES加密器 + """ + self._cipher = cipher or AESCipher() + + @property + def key(self) -> bytes: + """获取加密密钥""" + return self._cipher.key + + def set_key(self, key: bytes) -> None: + """ + 设置加密密钥 + + Args: + key: 256位密钥 + """ + self._cipher = AESCipher(key) + + def encrypt_file(self, input_path: str, output_path: str) -> bool: + """ + 加密文件 + + 实现文件加密传输 (需求 10.3) + + Args: + input_path: 输入文件路径 + output_path: 输出文件路径 + + Returns: + 加密成功返回True + + Raises: + EncryptionError: 加密失败 + """ + try: + # 读取整个文件 + with open(input_path, 'rb') as f: + plaintext = f.read() + + # 加密 + encrypted = self._cipher.encrypt(plaintext) + + # 写入加密文件 + with open(output_path, 'wb') as f: + f.write(encrypted.to_bytes()) + + logger.info(f"File encrypted: {input_path} -> {output_path}") + return True + + except Exception as e: + raise EncryptionError(f"Failed to encrypt file: {e}") + + def decrypt_file(self, input_path: str, output_path: str) -> bool: + """ + 解密文件 + + Args: + input_path: 加密文件路径 + output_path: 输出文件路径 + + Returns: + 解密成功返回True + + Raises: + DecryptionError: 解密失败 + """ + try: + # 读取加密文件 + with open(input_path, 'rb') as f: + encrypted_data = f.read() + + # 解密 + encrypted = EncryptedData.from_bytes(encrypted_data) + plaintext = self._cipher.decrypt(encrypted) + + # 写入解密文件 + with open(output_path, 'wb') as f: + f.write(plaintext) + + logger.info(f"File decrypted: {input_path} -> {output_path}") + return True + + except Exception as e: + raise DecryptionError(f"Failed to decrypt file: {e}") + + def encrypt_chunk(self, chunk_data: bytes) -> bytes: + """ + 加密文件块 + + 用于分块传输时的加密 + + Args: + chunk_data: 文件块数据 + + Returns: + 加密后的数据 + """ + encrypted = self._cipher.encrypt(chunk_data) + return encrypted.to_bytes() + + def decrypt_chunk(self, encrypted_chunk: bytes) -> bytes: + """ + 解密文件块 + + Args: + encrypted_chunk: 加密的文件块 + + Returns: + 解密后的数据 + """ + encrypted = EncryptedData.from_bytes(encrypted_chunk) + return self._cipher.decrypt(encrypted) + + + +class KeyManager: + """ + 密钥管理器 + + 实现密钥管理 (需求 10.2) + """ + + # 密钥存储目录 + KEY_DIR = "keys" + + def __init__(self, key_dir: Optional[str] = None): + """ + 初始化密钥管理器 + + Args: + key_dir: 密钥存储目录 + """ + self._key_dir = Path(key_dir or self.KEY_DIR) + self._key_dir.mkdir(parents=True, exist_ok=True) + + # 内存中的密钥缓存 + self._key_cache: dict = {} + + def generate_key_pair(self) -> Tuple[bytes, bytes]: + """ + 生成RSA密钥对 + + Returns: + (私钥, 公钥) 元组(PEM格式) + """ + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + backend=default_backend() + ) + + private_pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption() + ) + + public_pem = private_key.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo + ) + + return private_pem, public_pem + + def save_key(self, key_id: str, key_data: bytes, password: Optional[str] = None) -> bool: + """ + 保存密钥到文件 + + 实现密钥管理 (需求 10.2) + + Args: + key_id: 密钥标识 + key_data: 密钥数据 + password: 可选的加密密码 + + Returns: + 保存成功返回True + """ + try: + key_file = self._key_dir / f"{key_id}.key" + + if password: + # 使用密码加密密钥 + cipher = AESCipher() + encrypted = cipher.encrypt_with_password(key_data, password) + with open(key_file, 'wb') as f: + f.write(encrypted.to_bytes()) + else: + # 直接保存(不推荐用于生产环境) + with open(key_file, 'wb') as f: + f.write(key_data) + + logger.info(f"Key saved: {key_id}") + return True + + except Exception as e: + logger.error(f"Failed to save key {key_id}: {e}") + return False + + def load_key(self, key_id: str, password: Optional[str] = None) -> Optional[bytes]: + """ + 从文件加载密钥 + + Args: + key_id: 密钥标识 + password: 解密密码 + + Returns: + 密钥数据,如果加载失败返回None + """ + try: + key_file = self._key_dir / f"{key_id}.key" + + if not key_file.exists(): + logger.error(f"Key file not found: {key_id}") + return None + + with open(key_file, 'rb') as f: + data = f.read() + + if password: + # 解密密钥 + cipher = AESCipher() + encrypted = EncryptedData.from_bytes(data) + return cipher.decrypt_with_password(encrypted, password) + else: + return data + + except Exception as e: + logger.error(f"Failed to load key {key_id}: {e}") + return None + + def delete_key(self, key_id: str) -> bool: + """ + 删除密钥 + + Args: + key_id: 密钥标识 + + Returns: + 删除成功返回True + """ + try: + key_file = self._key_dir / f"{key_id}.key" + + if key_file.exists(): + key_file.unlink() + logger.info(f"Key deleted: {key_id}") + + # 从缓存中移除 + self._key_cache.pop(key_id, None) + + return True + + except Exception as e: + logger.error(f"Failed to delete key {key_id}: {e}") + return False + + def get_or_create_session_key(self, session_id: str) -> bytes: + """ + 获取或创建会话密钥 + + Args: + session_id: 会话标识 + + Returns: + 会话密钥 + """ + if session_id in self._key_cache: + return self._key_cache[session_id] + + # 生成新的会话密钥 + key = AESCipher.generate_key() + self._key_cache[session_id] = key + + return key + + def clear_session_key(self, session_id: str) -> None: + """ + 清除会话密钥 + + Args: + session_id: 会话标识 + """ + self._key_cache.pop(session_id, None) + + + +class LocalDataEncryptor: + """ + 本地数据加密器 + + 实现 AES-256 加密存储 (需求 10.2) + """ + + def __init__(self, password: Optional[str] = None, key: Optional[bytes] = None): + """ + 初始化本地数据加密器 + + Args: + password: 用户密码(用于派生密钥) + key: 直接提供的密钥 + """ + self._cipher = AESCipher() + self._password = password + self._salt: Optional[bytes] = None + + if key: + self._cipher = AESCipher(key) + elif password: + # 从密码派生密钥 + key, self._salt = AESCipher.derive_key_from_password(password) + self._cipher = AESCipher(key) + + def encrypt_data(self, data: Union[str, bytes, dict, list]) -> str: + """ + 加密数据 + + 实现对本地数据进行加密存储 (需求 10.2) + + Args: + data: 要加密的数据(字符串、字节、字典或列表) + + Returns: + Base64编码的加密数据 + """ + # 转换为字节 + if isinstance(data, str): + plaintext = data.encode('utf-8') + elif isinstance(data, (dict, list)): + plaintext = json.dumps(data, ensure_ascii=False).encode('utf-8') + else: + plaintext = data + + encrypted = self._cipher.encrypt(plaintext) + + # 如果有盐值,添加到加密数据中 + if self._salt: + encrypted.salt = self._salt + + return encrypted.to_base64() + + def decrypt_data(self, encrypted_data: str, as_json: bool = False) -> Union[bytes, dict]: + """ + 解密数据 + + Args: + encrypted_data: Base64编码的加密数据 + as_json: 是否解析为JSON + + Returns: + 解密后的数据 + """ + encrypted = EncryptedData.from_base64(encrypted_data) + plaintext = self._cipher.decrypt(encrypted) + + if as_json: + return json.loads(plaintext.decode('utf-8')) + + return plaintext + + def encrypt_chat_history(self, messages: list) -> str: + """ + 加密聊天记录 + + 实现对本地数据进行加密存储 (需求 10.2) + + Args: + messages: 聊天消息列表 + + Returns: + 加密后的数据 + """ + return self.encrypt_data(messages) + + def decrypt_chat_history(self, encrypted_data: str) -> list: + """ + 解密聊天记录 + + Args: + encrypted_data: 加密的聊天记录 + + Returns: + 聊天消息列表 + """ + return self.decrypt_data(encrypted_data, as_json=True) + + def save_encrypted_file(self, data: Union[str, bytes, dict], file_path: str) -> bool: + """ + 保存加密数据到文件 + + Args: + data: 要保存的数据 + file_path: 文件路径 + + Returns: + 保存成功返回True + """ + try: + encrypted = self.encrypt_data(data) + + # 确保目录存在 + Path(file_path).parent.mkdir(parents=True, exist_ok=True) + + with open(file_path, 'w', encoding='utf-8') as f: + f.write(encrypted) + + return True + + except Exception as e: + logger.error(f"Failed to save encrypted file: {e}") + return False + + def load_encrypted_file(self, file_path: str, as_json: bool = False) -> Optional[Union[bytes, dict]]: + """ + 从文件加载加密数据 + + Args: + file_path: 文件路径 + as_json: 是否解析为JSON + + Returns: + 解密后的数据,如果失败返回None + """ + try: + with open(file_path, 'r', encoding='utf-8') as f: + encrypted = f.read() + + return self.decrypt_data(encrypted, as_json=as_json) + + except FileNotFoundError: + logger.error(f"Encrypted file not found: {file_path}") + return None + except Exception as e: + logger.error(f"Failed to load encrypted file: {e}") + return None + + +# 需要导入ipaddress模块用于自签名证书 +import ipaddress + + +# 便捷函数 +def create_message_encryptor(key: Optional[bytes] = None) -> MessageEncryptor: + """创建消息加密器""" + cipher = AESCipher(key) if key else AESCipher() + return MessageEncryptor(cipher) + + +def create_file_encryptor(key: Optional[bytes] = None) -> FileEncryptor: + """创建文件加密器""" + cipher = AESCipher(key) if key else AESCipher() + return FileEncryptor(cipher) + + +def create_local_data_encryptor(password: str) -> LocalDataEncryptor: + """创建本地数据加密器""" + return LocalDataEncryptor(password=password) + + +def encrypt_message(message_data: bytes, key: bytes) -> bytes: + """快速加密消息""" + encryptor = MessageEncryptor(AESCipher(key)) + return encryptor.encrypt_message(message_data) + + +def decrypt_message(encrypted_data: bytes, key: bytes) -> bytes: + """快速解密消息""" + encryptor = MessageEncryptor(AESCipher(key)) + return encryptor.decrypt_message(encrypted_data) diff --git a/tests/test_security.py b/tests/test_security.py new file mode 100644 index 0000000..59c9825 --- /dev/null +++ b/tests/test_security.py @@ -0,0 +1,436 @@ +# P2P Network Communication - Security Module Tests +""" +安全模块测试 +测试传输加密、本地数据加密和密钥管理功能 + +需求: 10.1, 10.2, 10.3 +""" + +import os +import tempfile +import pytest +from pathlib import Path + +from shared.security import ( + AESCipher, + EncryptedData, + MessageEncryptor, + FileEncryptor, + KeyManager, + LocalDataEncryptor, + TLSManager, + EncryptionError, + DecryptionError, + create_message_encryptor, + create_file_encryptor, + create_local_data_encryptor, + encrypt_message, + decrypt_message, +) + + +class TestAESCipher: + """AES-256-GCM 加密器测试""" + + def test_generate_key(self): + """测试密钥生成""" + key = AESCipher.generate_key() + assert len(key) == 32 # 256 bits + + # 确保每次生成的密钥不同 + key2 = AESCipher.generate_key() + assert key != key2 + + def test_generate_iv(self): + """测试IV生成""" + iv = AESCipher.generate_iv() + assert len(iv) == 12 # 96 bits + + def test_encrypt_decrypt(self): + """测试加密解密往返""" + cipher = AESCipher() + plaintext = b"Hello, World! This is a test message." + + encrypted = cipher.encrypt(plaintext) + decrypted = cipher.decrypt(encrypted) + + assert decrypted == plaintext + + def test_encrypt_decrypt_empty(self): + """测试空数据加密解密""" + cipher = AESCipher() + plaintext = b"" + + encrypted = cipher.encrypt(plaintext) + decrypted = cipher.decrypt(encrypted) + + assert decrypted == plaintext + + def test_encrypt_decrypt_large_data(self): + """测试大数据加密解密""" + cipher = AESCipher() + plaintext = os.urandom(1024 * 1024) # 1MB + + encrypted = cipher.encrypt(plaintext) + decrypted = cipher.decrypt(encrypted) + + assert decrypted == plaintext + + def test_encrypt_decrypt_unicode(self): + """测试Unicode数据加密解密""" + cipher = AESCipher() + plaintext = "你好,世界!这是一条测试消息。🎉".encode('utf-8') + + encrypted = cipher.encrypt(plaintext) + decrypted = cipher.decrypt(encrypted) + + assert decrypted == plaintext + + def test_different_keys_produce_different_ciphertext(self): + """测试不同密钥产生不同密文""" + cipher1 = AESCipher() + cipher2 = AESCipher() + plaintext = b"Same message" + + encrypted1 = cipher1.encrypt(plaintext) + encrypted2 = cipher2.encrypt(plaintext) + + assert encrypted1.ciphertext != encrypted2.ciphertext + + def test_wrong_key_fails_decryption(self): + """测试错误密钥解密失败""" + cipher1 = AESCipher() + cipher2 = AESCipher() + plaintext = b"Secret message" + + encrypted = cipher1.encrypt(plaintext) + + with pytest.raises(DecryptionError): + cipher2.decrypt(encrypted) + + def test_derive_key_from_password(self): + """测试从密码派生密钥""" + password = "my_secure_password" + + key1, salt1 = AESCipher.derive_key_from_password(password) + assert len(key1) == 32 + assert len(salt1) == 16 + + # 相同密码和盐应产生相同密钥 + key2, _ = AESCipher.derive_key_from_password(password, salt1) + assert key1 == key2 + + # 不同盐应产生不同密钥 + key3, salt3 = AESCipher.derive_key_from_password(password) + assert key1 != key3 + + def test_encrypt_with_password(self): + """测试使用密码加密""" + cipher = AESCipher() + password = "test_password" + plaintext = b"Secret data" + + encrypted = cipher.encrypt_with_password(plaintext, password) + assert encrypted.salt # 应包含盐值 + + decrypted = cipher.decrypt_with_password(encrypted, password) + assert decrypted == plaintext + + def test_wrong_password_fails(self): + """测试错误密码解密失败""" + cipher = AESCipher() + plaintext = b"Secret data" + + encrypted = cipher.encrypt_with_password(plaintext, "correct_password") + + with pytest.raises(DecryptionError): + cipher.decrypt_with_password(encrypted, "wrong_password") + + +class TestEncryptedData: + """加密数据结构测试""" + + def test_to_bytes_from_bytes(self): + """测试序列化和反序列化""" + original = EncryptedData( + ciphertext=b"encrypted_content", + iv=b"123456789012", + tag=b"1234567890123456", + salt=b"salt_value_here!" + ) + + serialized = original.to_bytes() + restored = EncryptedData.from_bytes(serialized) + + assert restored.ciphertext == original.ciphertext + assert restored.iv == original.iv + assert restored.tag == original.tag + assert restored.salt == original.salt + + def test_to_base64_from_base64(self): + """测试Base64编码和解码""" + original = EncryptedData( + ciphertext=b"encrypted_content", + iv=b"123456789012", + tag=b"1234567890123456" + ) + + base64_str = original.to_base64() + restored = EncryptedData.from_base64(base64_str) + + assert restored.ciphertext == original.ciphertext + assert restored.iv == original.iv + assert restored.tag == original.tag + + +class TestMessageEncryptor: + """消息加密器测试""" + + def test_encrypt_decrypt_message(self): + """测试消息加密解密""" + encryptor = MessageEncryptor() + message_data = b'{"type": "text", "content": "Hello!"}' + + encrypted = encryptor.encrypt_message(message_data) + decrypted = encryptor.decrypt_message(encrypted) + + assert decrypted == message_data + + def test_set_key(self): + """测试设置密钥""" + key = AESCipher.generate_key() + encryptor = MessageEncryptor() + encryptor.set_key(key) + + assert encryptor.key == key + + def test_shared_key_encryption(self): + """测试共享密钥加密""" + key = AESCipher.generate_key() + + encryptor1 = MessageEncryptor(AESCipher(key)) + encryptor2 = MessageEncryptor(AESCipher(key)) + + message = b"Shared secret message" + + encrypted = encryptor1.encrypt_message(message) + decrypted = encryptor2.decrypt_message(encrypted) + + assert decrypted == message + + +class TestFileEncryptor: + """文件加密器测试""" + + def test_encrypt_decrypt_file(self): + """测试文件加密解密""" + encryptor = FileEncryptor() + + with tempfile.TemporaryDirectory() as tmpdir: + # 创建测试文件 + input_file = Path(tmpdir) / "test.txt" + encrypted_file = Path(tmpdir) / "test.enc" + output_file = Path(tmpdir) / "test_decrypted.txt" + + original_content = b"This is test file content.\n" * 100 + input_file.write_bytes(original_content) + + # 加密 + encryptor.encrypt_file(str(input_file), str(encrypted_file)) + assert encrypted_file.exists() + + # 解密 + encryptor.decrypt_file(str(encrypted_file), str(output_file)) + assert output_file.exists() + + # 验证内容 + decrypted_content = output_file.read_bytes() + assert decrypted_content == original_content + + def test_encrypt_decrypt_chunk(self): + """测试文件块加密解密""" + encryptor = FileEncryptor() + chunk_data = os.urandom(64 * 1024) # 64KB + + encrypted = encryptor.encrypt_chunk(chunk_data) + decrypted = encryptor.decrypt_chunk(encrypted) + + assert decrypted == chunk_data + + +class TestKeyManager: + """密钥管理器测试""" + + def test_generate_key_pair(self): + """测试生成密钥对""" + with tempfile.TemporaryDirectory() as tmpdir: + manager = KeyManager(tmpdir) + private_key, public_key = manager.generate_key_pair() + + assert b"PRIVATE KEY" in private_key + assert b"PUBLIC KEY" in public_key + + def test_save_load_key(self): + """测试保存和加载密钥""" + with tempfile.TemporaryDirectory() as tmpdir: + manager = KeyManager(tmpdir) + key_data = AESCipher.generate_key() + + # 保存 + assert manager.save_key("test_key", key_data) + + # 加载 + loaded = manager.load_key("test_key") + assert loaded == key_data + + def test_save_load_key_with_password(self): + """测试使用密码保存和加载密钥""" + with tempfile.TemporaryDirectory() as tmpdir: + manager = KeyManager(tmpdir) + key_data = AESCipher.generate_key() + password = "secure_password" + + # 保存 + assert manager.save_key("protected_key", key_data, password) + + # 加载 + loaded = manager.load_key("protected_key", password) + assert loaded == key_data + + def test_delete_key(self): + """测试删除密钥""" + with tempfile.TemporaryDirectory() as tmpdir: + manager = KeyManager(tmpdir) + key_data = AESCipher.generate_key() + + manager.save_key("to_delete", key_data) + assert manager.delete_key("to_delete") + + # 验证已删除 + assert manager.load_key("to_delete") is None + + def test_session_key_management(self): + """测试会话密钥管理""" + with tempfile.TemporaryDirectory() as tmpdir: + manager = KeyManager(tmpdir) + + # 获取或创建会话密钥 + key1 = manager.get_or_create_session_key("session1") + assert len(key1) == 32 + + # 再次获取应返回相同密钥 + key2 = manager.get_or_create_session_key("session1") + assert key1 == key2 + + # 不同会话应有不同密钥 + key3 = manager.get_or_create_session_key("session2") + assert key1 != key3 + + # 清除会话密钥 + manager.clear_session_key("session1") + key4 = manager.get_or_create_session_key("session1") + assert key1 != key4 # 应该是新密钥 + + +class TestLocalDataEncryptor: + """本地数据加密器测试""" + + def test_encrypt_decrypt_string(self): + """测试字符串加密解密""" + encryptor = LocalDataEncryptor(password="test_password") + original = "Hello, World!" + + encrypted = encryptor.encrypt_data(original) + decrypted = encryptor.decrypt_data(encrypted) + + assert decrypted.decode('utf-8') == original + + def test_encrypt_decrypt_dict(self): + """测试字典加密解密""" + encryptor = LocalDataEncryptor(password="test_password") + original = {"name": "Test", "value": 123, "nested": {"key": "value"}} + + encrypted = encryptor.encrypt_data(original) + decrypted = encryptor.decrypt_data(encrypted, as_json=True) + + assert decrypted == original + + def test_encrypt_decrypt_chat_history(self): + """测试聊天记录加密解密""" + encryptor = LocalDataEncryptor(password="chat_password") + messages = [ + {"sender": "user1", "content": "Hello", "timestamp": 1234567890}, + {"sender": "user2", "content": "Hi there!", "timestamp": 1234567891}, + ] + + encrypted = encryptor.encrypt_chat_history(messages) + decrypted = encryptor.decrypt_chat_history(encrypted) + + assert decrypted == messages + + def test_save_load_encrypted_file(self): + """测试保存和加载加密文件""" + with tempfile.TemporaryDirectory() as tmpdir: + encryptor = LocalDataEncryptor(password="file_password") + file_path = Path(tmpdir) / "encrypted_data.enc" + + original_data = {"key": "value", "list": [1, 2, 3]} + + # 保存 + assert encryptor.save_encrypted_file(original_data, str(file_path)) + assert file_path.exists() + + # 加载 + loaded = encryptor.load_encrypted_file(str(file_path), as_json=True) + assert loaded == original_data + + +class TestConvenienceFunctions: + """便捷函数测试""" + + def test_create_message_encryptor(self): + """测试创建消息加密器""" + encryptor = create_message_encryptor() + assert encryptor is not None + assert len(encryptor.key) == 32 + + def test_create_file_encryptor(self): + """测试创建文件加密器""" + encryptor = create_file_encryptor() + assert encryptor is not None + assert len(encryptor.key) == 32 + + def test_create_local_data_encryptor(self): + """测试创建本地数据加密器""" + encryptor = create_local_data_encryptor("password") + assert encryptor is not None + + def test_encrypt_decrypt_message_functions(self): + """测试快速加密解密函数""" + key = AESCipher.generate_key() + message = b"Quick encryption test" + + encrypted = encrypt_message(message, key) + decrypted = decrypt_message(encrypted, key) + + assert decrypted == message + + +class TestTLSManager: + """TLS管理器测试""" + + def test_create_client_ssl_context_no_verify(self): + """测试创建客户端SSL上下文(不验证)""" + manager = TLSManager() + context = manager.create_client_ssl_context(verify=False) + + assert context is not None + assert context.verify_mode.name == "CERT_NONE" + + def test_create_client_ssl_context_with_verify(self): + """测试创建客户端SSL上下文(验证)""" + manager = TLSManager() + context = manager.create_client_ssl_context(verify=True) + + assert context is not None + assert context.verify_mode.name == "CERT_REQUIRED"