You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

625 lines
27 KiB

"""
服务器核心逻辑
"""
import socket
import struct
import threading
import json
import time
import logging
from datetime import datetime, timezone, timedelta
def beijing_now():
    """Return the current Beijing time (UTC+8) as a naive ``datetime``.

    The timezone info is stripped from the result, so it must not be mixed
    with timezone-aware datetimes.
    """
    beijing_tz = timezone(timedelta(hours=8))
    aware_now = datetime.now(beijing_tz)
    return aware_now.replace(tzinfo=None)
def beijing_now_str(fmt='%Y-%m-%d %H:%M:%S'):
    """Return the current Beijing time formatted with ``strftime(fmt)``.

    The default format is ``YYYY-MM-DD HH:MM:SS``.
    """
    current = beijing_now()
    return current.strftime(fmt)
from database import Database
class ClientHandler(threading.Thread):
    """One handler thread per connected client.

    Wire format (both directions): a 4-byte big-endian unsigned length prefix
    followed by that many bytes of UTF-8 encoded JSON. Incoming messages are
    JSON objects whose ``type`` field selects the handler method.
    """

    # Incoming frames longer than this close the connection.
    MAX_MESSAGE_SIZE = 10 * 1024 * 1024

    def __init__(self, client_socket, client_address, server):
        super().__init__(daemon=True)
        self.client_socket = client_socket
        self.client_address = client_address
        self.server = server  # owning ChatServer: provides db, online_users, _users_lock
        self.user_id = None
        self.username = None
        self.running = True
        # send_message() is also called from other clients' threads (chat
        # forwarding, status broadcasts); serialize writes so that two frames
        # can never interleave on the wire.
        self._send_lock = threading.Lock()
        # Idle period after which run() sends a heartbeat.
        self.client_socket.settimeout(60)

    # ── Send / receive ─────────────────────────────────────
    def _recv_exact(self, n, idle_timeout=False):
        """Read exactly *n* bytes from the client socket.

        Returns the bytes, or None if the peer closed the connection or this
        handler was stopped. Socket timeouts are retried, except that when
        *idle_timeout* is true and no byte of this read has arrived yet the
        ``socket.timeout`` is re-raised so the caller can treat the connection
        as idle. Timeouts after a partial read are always retried so the
        framing stays intact.
        """
        buf = bytearray()  # bytearray avoids quadratic copying for large frames
        while len(buf) < n:
            if not self.running:
                return None
            try:
                chunk = self.client_socket.recv(n - len(buf))
            except socket.timeout:
                if idle_timeout and not buf:
                    raise
                continue
            if not chunk:
                return None
            buf += chunk
        return bytes(buf)

    def send_message(self, message):
        """Serialize *message* as JSON and send it as one length-prefixed frame.

        Safe to call from any thread. Does nothing once this handler has been
        disconnected. On a serialization or socket error the error is logged
        and the client is disconnected.
        """
        if not self.running:
            return
        try:
            data = json.dumps(message, ensure_ascii=False).encode('utf-8')
            with self._send_lock:
                self.client_socket.sendall(struct.pack('>I', len(data)) + data)
        except (OSError, TypeError, ValueError) as e:
            # The send lock has been released by now, so this thread never holds
            # one client's send lock while disconnect() sends to other clients.
            logging.error(f"发送消息失败 {self.client_address}: {e}")
            self.disconnect()

    def send_error(self, msg):
        """Send an ``error`` message carrying the text *msg*."""
        self.send_message({'type': 'error', 'message': msg})

    # ── Main loop ──────────────────────────────────────────
    def run(self):
        """Read and dispatch frames until the client goes away, then clean up."""
        logging.info(f"客户端连接: {self.client_address}")
        while self.running:
            try:
                header = self._recv_exact(4, idle_timeout=True)
                if not header:
                    break
                msg_len = struct.unpack('>I', header)[0]
                if msg_len > self.MAX_MESSAGE_SIZE:  # reject oversized frames
                    break
                data = self._recv_exact(msg_len)
                if not data:
                    break
                try:
                    message = json.loads(data.decode('utf-8'))
                except (UnicodeDecodeError, json.JSONDecodeError):
                    self.send_error("消息格式错误")
                    continue
                self.handle_message(message)
            except socket.timeout:
                # Nothing received for the idle period: send a heartbeat.
                self.send_message({'type': 'heartbeat', 'timestamp': time.time()})
            except Exception as e:
                if self.running:
                    logging.exception(f"处理消息出错: {e}")
                break
        self.disconnect()

    # ── Message routing ────────────────────────────────────
    def handle_message(self, message):
        """Dispatch a decoded JSON message to the handler for its ``type``."""
        if not isinstance(message, dict):
            # Valid JSON, but not an object (e.g. a list or a number).
            self.send_error("消息格式错误")
            return
        t = message.get('type')
        handlers = {
            'login': self.handle_login,
            'register': self.handle_register,
            'chat': self.handle_chat,
            'get_users': lambda m: self.handle_get_users(),
            'search_users': self.handle_search_users,
            'add_friend': self.handle_add_friend,
            'remove_friend': self.handle_remove_friend,
            'get_friends': lambda m: self.handle_get_friends(),
            'get_history': self.handle_get_history,
            'create_group': self.handle_create_group,
            'get_groups': lambda m: self.handle_get_groups(),
            'get_all_groups': lambda m: self.handle_get_all_groups(),
            'join_group': self.handle_join_group,
            'group_chat': self.handle_group_chat,
            'get_group_history': self.handle_get_group_history,
            'get_group_members': self.handle_get_group_members,
            'leave_group': self.handle_leave_group,
            'invite_to_group': self.handle_invite_to_group,
            'search_messages': self.handle_search_messages,
            'get_all_history': self.handle_get_all_history,
            'get_recent_history': self.handle_get_recent_history,
            'heartbeat': lambda m: self.send_message({'type': 'heartbeat_response', 'timestamp': time.time()}),
            'change_username': self.handle_change_username,
            'change_nickname': self.handle_change_nickname,
            'change_password': self.handle_change_password,
            'chat_image': self.handle_chat_image,
            'group_chat_image': self.handle_group_chat_image,
        }
        # An unhashable ``type`` (e.g. a list) would make dict.get raise.
        handler = handlers.get(t) if isinstance(t, str) else None
        if handler:
            handler(message)
        else:
            self.send_error(f"未知消息类型: {t}")

    # ── Login / registration ───────────────────────────────
    def handle_login(self, msg):
        """Authenticate the client and register this connection as online.

        An existing session of the same account is kicked only after the
        credentials have been verified, so a wrong password cannot disconnect
        someone else.
        """
        username = msg.get('username', '').strip()
        password = msg.get('password', '').strip()
        if not username or not password:
            self.send_message({'type': 'login_response', 'success': False, 'message': '用户名和密码不能为空'})
            return
        success, msg_text, user_info = self.server.db.user_login(username, password)
        if not success:
            self.send_message({'type': 'login_response', 'success': False, 'message': msg_text})
            return
        new_id = user_info['id']
        # Look up an older session under the lock, but kick it outside the
        # lock: disconnect() acquires _users_lock itself and threading.Lock is
        # not re-entrant, so calling it while holding the lock would deadlock.
        with self.server._users_lock:
            existing = self.server.online_users.get(new_id)
        old_handler = existing['handler'] if existing else None
        if old_handler is not None and old_handler is not self:
            old_handler.send_message({'type': 'error', 'message': '账号在其他地方登录'})
            old_handler.disconnect()
        login_time = beijing_now_str()
        with self.server._users_lock:
            # If this connection was logged in as a different account, drop
            # that entry so messages for it stop being routed to this socket.
            if self.user_id is not None and self.user_id != new_id:
                stale = self.server.online_users.get(self.user_id)
                if stale is not None and stale['handler'] is self:
                    del self.server.online_users[self.user_id]
            self.user_id = new_id
            self.username = username
            self.server.online_users[new_id] = {
                'handler': self,
                'username': username,
                'nickname': user_info.get('nickname', username),
                'login_time': login_time,
            }
            online_count = len(self.server.online_users)
        self.send_message({
            'type': 'login_response', 'success': True,
            'message': '登录成功', 'user_info': user_info,
            'online_count': online_count
        })
        self.broadcast_user_status(True)
        logging.info(f"用户登录: {username}")

    def handle_register(self, msg):
        """Create a new account; the nickname defaults to the username."""
        username = msg.get('username', '').strip()
        password = msg.get('password', '').strip()
        nickname = msg.get('nickname', '').strip()
        if not username or not password:
            self.send_message({'type': 'register_response', 'success': False, 'message': '用户名和密码不能为空'})
            return
        success, text = self.server.db.user_register(username, password, nickname or username)
        self.send_message({'type': 'register_response', 'success': success, 'message': text})
        if success:
            logging.info(f"新用户注册: {username}")

    # ── Profile changes ────────────────────────────────────
    def handle_change_username(self, msg):
        """Rename the logged-in user (4-20 chars: letters, digits, underscore)."""
        if not self._require_login():
            return
        new_username = msg.get('new_username', '').strip()
        if not new_username:
            self.send_message({'type': 'change_username_response', 'success': False, 'message': '新用户名不能为空'})
            return
        if len(new_username) < 4 or len(new_username) > 20:
            self.send_message({'type': 'change_username_response', 'success': False, 'message': '用户名长度需4-20位'})
            return
        if not all(c.isalnum() or c == '_' for c in new_username):
            self.send_message({'type': 'change_username_response', 'success': False, 'message': '用户名只能包含字母、数字和下划线'})
            return
        success, msg_text = self.server.db.change_username(self.user_id, new_username)
        if success:
            old_username = self.username
            self.username = new_username
            with self.server._users_lock:
                if self.user_id in self.server.online_users:
                    self.server.online_users[self.user_id]['username'] = new_username
            self.send_message({'type': 'change_username_response', 'success': True,
                               'message': msg_text, 'new_username': new_username})
            logging.info(f'用户改名: {old_username} -> {new_username}')
        else:
            self.send_message({'type': 'change_username_response', 'success': False, 'message': msg_text})

    def handle_change_nickname(self, msg):
        """Change the logged-in user's nickname and update the online registry."""
        if not self._require_login():
            return
        new_nickname = msg.get('new_nickname', '').strip()
        if not new_nickname:
            self.send_message({'type': 'change_nickname_response', 'success': False, 'message': '昵称不能为空'})
            return
        success, msg_text = self.server.db.change_nickname(self.user_id, new_nickname)
        if success:
            with self.server._users_lock:
                if self.user_id in self.server.online_users:
                    self.server.online_users[self.user_id]['nickname'] = new_nickname
        self.send_message({'type': 'change_nickname_response', 'success': success, 'message': msg_text,
                           'new_nickname': new_nickname})

    def handle_change_password(self, msg):
        """Change the password after the DB verifies the old one (new: 6-20 chars)."""
        if not self._require_login():
            return
        old_password = msg.get('old_password', '').strip()
        new_password = msg.get('new_password', '').strip()
        if not old_password or not new_password:
            self.send_message({'type': 'change_password_response', 'success': False, 'message': '密码不能为空'})
            return
        if len(new_password) < 6 or len(new_password) > 20:
            self.send_message({'type': 'change_password_response', 'success': False, 'message': '密码长度需6-20位'})
            return
        success, msg_text = self.server.db.change_password(self.user_id, old_password, new_password)
        self.send_message({'type': 'change_password_response', 'success': success, 'message': msg_text})

    # ── Private chat ───────────────────────────────────────
    def handle_chat(self, msg):
        """Store a private text message and forward it to the receiver if online."""
        if not self._require_login():
            return
        receiver = msg.get('receiver', '').strip()
        content = msg.get('content', '').strip()
        if not receiver or not content:
            self.send_error("接收方或消息内容不能为空")
            return
        receiver_info = self.server.db.get_user_by_username(receiver)
        if not receiver_info:
            self.send_error("用户不存在")
            return
        self.server.db.save_message(self.user_id, receiver_info['id'], content)
        chat_msg = {
            'type': 'chat',
            'sender': self.username,
            'receiver': receiver,
            'content': content,
            'timestamp': beijing_now_str()
        }
        # Forward only to the receiver (if online); the sender's client shows its own copy.
        self._send_to_user(receiver_info['id'], chat_msg)

    # ── User list / search ─────────────────────────────────
    def handle_get_users(self):
        """Send all users (as returned by the DB) annotated with ``is_online``."""
        if not self._require_login():
            return
        users = self.server.db.get_all_users(self.user_id)
        self.send_message({'type': 'users_list', 'users': self._mark_online(users)})

    def handle_search_users(self, msg):
        """Search users by keyword (empty keyword lists all users)."""
        if not self._require_login():
            return
        keyword = msg.get('keyword', '').strip()
        users = self.server.db.search_users(keyword, self.user_id) if keyword else self.server.db.get_all_users(self.user_id)
        self.send_message({'type': 'users_list', 'users': self._mark_online(users)})

    # ── Friends ────────────────────────────────────────────
    def handle_add_friend(self, msg):
        """Add the user named in ``username`` as a friend."""
        if not self._require_login():
            return
        friend_username = msg.get('username', '').strip()
        if not friend_username:
            self.send_error("请输入用户名")
            return
        success, text = self.server.db.add_friend(self.user_id, friend_username)
        self.send_message({'type': 'add_friend_response', 'success': success, 'message': text})

    def handle_remove_friend(self, msg):
        """Remove the friend identified by ``friend_id``."""
        if not self._require_login():
            return
        friend_id = msg.get('friend_id')
        if not friend_id:
            self.send_error("缺少好友ID")
            return
        success, text = self.server.db.remove_friend(self.user_id, friend_id)
        self.send_message({'type': 'remove_friend_response', 'success': success, 'message': text})

    def handle_get_friends(self):
        """Send the friend list annotated with ``is_online``."""
        if not self._require_login():
            return
        friends = self.server.db.get_friends(self.user_id)
        self.send_message({'type': 'friends_list', 'friends': self._mark_online(friends)})

    # ── Chat history ───────────────────────────────────────
    def handle_get_history(self, msg):
        """Send the private chat history with ``username`` (limit 1-200, default 50)."""
        if not self._require_login():
            return
        friend_username = msg.get('username', '').strip()
        limit = self._parse_limit(msg, 50, 200)
        friend_info = self.server.db.get_user_by_username(friend_username)
        if not friend_info:
            self.send_error("用户不存在")
            return
        history = self.server.db.get_chat_history(self.user_id, friend_info['id'], limit)
        self.send_message({'type': 'chat_history', 'friend': friend_username,
                           'history': self._format_messages(history)})

    def handle_get_group_history(self, msg):
        """Send a group's chat history (limit 1-200, default 50)."""
        if not self._require_login():
            return
        group_id = msg.get('group_id')
        limit = self._parse_limit(msg, 50, 200)
        if not group_id:
            self.send_error("缺少群组ID")
            return
        history = self.server.db.get_group_history(group_id, limit)
        self.send_message({'type': 'group_history', 'group_id': group_id,
                           'history': self._format_messages(history)})

    # ── Groups ─────────────────────────────────────────────
    def handle_create_group(self, msg):
        """Create a group owned by the current user."""
        if not self._require_login():
            return
        group_name = msg.get('group_name', '').strip()
        if not group_name:
            self.send_error("请输入群组名称")
            return
        success, result = self.server.db.create_group(group_name, self.user_id)
        if success:
            self.send_message({'type': 'create_group_response', 'success': True,
                               'message': f"群组'{group_name}'创建成功", 'group_id': result})
        else:
            self.send_message({'type': 'create_group_response', 'success': False,
                               'message': f"创建失败: {result}"})

    def handle_get_groups(self):
        """Send the groups the current user belongs to."""
        if not self._require_login():
            return
        groups = self.server.db.get_user_groups(self.user_id)
        self.send_message({'type': 'groups_list', 'groups': groups})

    def handle_get_all_groups(self):
        """Send every group known to the database."""
        if not self._require_login():
            return
        groups = self.server.db.get_all_groups()
        self.send_message({'type': 'all_groups_list', 'groups': groups})

    def handle_join_group(self, msg):
        """Join the group identified by ``group_id``."""
        if not self._require_login():
            return
        group_id = msg.get('group_id')
        if not group_id:
            self.send_error("缺少群组ID")
            return
        success, text = self.server.db.join_group(group_id, self.user_id)
        self.send_message({'type': 'join_group_response', 'success': success, 'message': text,
                           'group_id': group_id})

    def handle_group_chat(self, msg):
        """Store a group text message and forward it to the other online members.

        Only members of the group may post.
        """
        if not self._require_login():
            return
        group_id = msg.get('group_id')
        content = msg.get('content', '').strip()
        if not group_id or not content:
            self.send_error("群组ID或消息内容不能为空")
            return
        members = self.server.db.get_group_members(group_id)
        if not self._is_member(members):
            self.send_error("你不是该群成员")
            return
        self.server.db.save_message(self.user_id, None, content, group_id=group_id)
        group_msg = {
            'type': 'group_chat',
            'sender': self.username,
            'group_id': group_id,
            'content': content,
            'timestamp': beijing_now_str()
        }
        # Forward to the other online members; the sender's client shows its own copy.
        self._send_to_members(members, group_msg)

    def handle_chat_image(self, msg):
        """Record a private image message (filename only) and forward the data if online."""
        if not self._require_login():
            return
        receiver = msg.get('receiver', '').strip()
        filename = msg.get('filename', 'image.png')
        file_data_b64 = msg.get('data', '')
        if not receiver or not file_data_b64:
            self.send_error("接收方或图片数据不能为空")
            return
        receiver_info = self.server.db.get_user_by_username(receiver)
        if not receiver_info:
            self.send_error("用户不存在")
            return
        self.server.db.save_message(self.user_id, receiver_info['id'], filename, msg_type='image')
        image_msg = {
            'type': 'chat_image',
            'sender': self.username,
            'receiver': receiver,
            'filename': filename,
            'data': file_data_b64,
            'timestamp': beijing_now_str()
        }
        self._send_to_user(receiver_info['id'], image_msg)

    def handle_group_chat_image(self, msg):
        """Record a group image message and forward it to the other online members.

        Only members of the group may post.
        """
        if not self._require_login():
            return
        group_id = msg.get('group_id')
        filename = msg.get('filename', 'image.png')
        file_data_b64 = msg.get('data', '')
        if not group_id or not file_data_b64:
            self.send_error("群组ID或图片数据不能为空")
            return
        members = self.server.db.get_group_members(group_id)
        if not self._is_member(members):
            self.send_error("你不是该群成员")
            return
        self.server.db.save_message(self.user_id, None, filename, msg_type='image', group_id=group_id)
        image_msg = {
            'type': 'group_chat_image',
            'sender': self.username,
            'group_id': group_id,
            'filename': filename,
            'data': file_data_b64,
            'timestamp': beijing_now_str()
        }
        self._send_to_members(members, image_msg)

    def handle_get_group_members(self, msg):
        """Send a group's members annotated with ``is_online``."""
        if not self._require_login():
            return
        group_id = msg.get('group_id')
        if not group_id:
            self.send_error("缺少群组ID")
            return
        members = self.server.db.get_group_members(group_id)
        self.send_message({'type': 'group_members', 'group_id': group_id,
                           'members': self._mark_online(members)})

    # ── Group management (continued) ───────────────────────
    def handle_leave_group(self, msg):
        """Leave the group identified by ``group_id``."""
        if not self._require_login():
            return
        group_id = msg.get('group_id')
        if not group_id:
            self.send_error("缺少群组ID")
            return
        success, text = self.server.db.leave_group(group_id, self.user_id)
        self.send_message({'type': 'leave_group_response', 'success': success, 'message': text, 'group_id': group_id})

    def handle_invite_to_group(self, msg):
        """Add the user named in ``username`` to the group ``group_id``."""
        if not self._require_login():
            return
        group_id = msg.get('group_id')
        username = msg.get('username', '').strip()
        if not group_id or not username:
            self.send_error("缺少群组ID或用户名")
            return
        success, text = self.server.db.invite_to_group(group_id, username)
        self.send_message({'type': 'invite_to_group_response', 'success': success, 'message': text, 'group_id': group_id})

    # ── Message search ─────────────────────────────────────
    def handle_search_messages(self, msg):
        """Search the user's messages for ``keyword`` in one private/group chat."""
        if not self._require_login():
            return
        keyword = msg.get('keyword', '').strip()
        chat_type = msg.get('chat_type', 'private')
        target_id = msg.get('target_id')
        if not keyword:
            self.send_error("请输入搜索关键词")
            return
        results = self.server.db.search_messages(self.user_id, keyword, chat_type, target_id)
        self.send_message({'type': 'search_results', 'keyword': keyword,
                           'results': self._format_messages(results),
                           'chat_type': chat_type, 'target_id': target_id})

    def handle_get_all_history(self, msg):
        """Send the user's history across all chats (limit 1-500, default 200)."""
        if not self._require_login():
            return
        limit = self._parse_limit(msg, 200, 500)
        results = self.server.db.get_all_user_history(self.user_id, limit)
        formatted = [{
            'sender': m['sender_name'],
            'content': m['content'],
            'timestamp': m['created_at'],
            'chat_type': m['chat_type'],
            'target_name': m['target_name'],
            'group_id': m['group_id'],
        } for m in results]
        self.send_message({'type': 'all_history', 'history': formatted})

    def handle_get_recent_history(self, msg):
        """Send recent history of one private/group chat (limit 1-200, default 50)."""
        if not self._require_login():
            return
        chat_type = msg.get('chat_type', 'private')
        target_id = msg.get('target_id')
        limit = self._parse_limit(msg, 50, 200)
        results = self.server.db.get_recent_history(self.user_id, chat_type, target_id, limit)
        self.send_message({'type': 'recent_history', 'chat_type': chat_type,
                           'target_id': target_id, 'history': self._format_messages(results)})

    # ── Helpers ────────────────────────────────────────────
    def _require_login(self):
        """Return True if logged in; otherwise send an error and return False."""
        if not self.user_id:
            self.send_error("请先登录")
            return False
        return True

    @staticmethod
    def _parse_limit(msg, default, maximum):
        """Return ``msg['limit']`` clamped to ``[1, maximum]``.

        Falls back to *default* when the field is missing or not an integer.
        The lower bound blocks negative values, which an SQL LIMIT may treat
        as "no limit".
        """
        try:
            value = int(msg.get('limit', default))
        except (TypeError, ValueError):
            value = default
        return max(1, min(value, maximum))

    @staticmethod
    def _format_messages(rows):
        """Convert DB message rows to ``{sender, content, timestamp}`` dicts."""
        return [{'sender': m['username'], 'content': m['content'], 'timestamp': m['created_at']}
                for m in rows]

    def _mark_online(self, records):
        """Set ``is_online`` on each record (keyed by ``id``) and return the list."""
        with self.server._users_lock:
            online_ids = set(self.server.online_users)
        for record in records:
            record['is_online'] = record['id'] in online_ids
        return records

    def _is_member(self, members):
        """Return True if the current user's id appears in *members*."""
        return any(m['id'] == self.user_id for m in members)

    def _send_to_user(self, user_id, message):
        """Send *message* to *user_id*'s connection if that user is online."""
        with self.server._users_lock:
            info = self.server.online_users.get(user_id)
        if info is not None:
            info['handler'].send_message(message)

    def _send_to_members(self, members, message):
        """Send *message* to every online member except the current user."""
        with self.server._users_lock:
            online = self.server.online_users
            targets = [online[m['id']]['handler'] for m in members
                       if m['id'] != self.user_id and m['id'] in online]
        # Send outside the lock: a failed send calls disconnect(), which takes it.
        for handler in targets:
            handler.send_message(message)

    def broadcast_user_status(self, is_online):
        """Tell every other online user that this user went online/offline."""
        with self.server._users_lock:
            online_count = len(self.server.online_users)
            targets = [info['handler'] for uid, info in self.server.online_users.items()
                       if uid != self.user_id]
        msg = {
            'type': 'user_status',
            'user_id': self.user_id,
            'username': self.username,
            'is_online': is_online,
            'online_count': online_count
        }
        for handler in targets:
            try:
                handler.send_message(msg)
            except Exception:
                # Best effort: one failing recipient must not stop the broadcast.
                pass

    def disconnect(self):
        """Stop this handler, unregister it, notify others and close the socket.

        Idempotent and safe to call from any thread; the check-and-set of
        ``running`` happens under ``_users_lock`` so only one caller proceeds.
        """
        with self.server._users_lock:
            if not self.running:
                return
            self.running = False
            entry = self.server.online_users.get(self.user_id) if self.user_id is not None else None
            # Only remove the entry if it is ours: a newer login for the same
            # account may already have replaced it.
            owns_entry = entry is not None and entry['handler'] is self
            if owns_entry:
                del self.server.online_users[self.user_id]
            should_broadcast = owns_entry and bool(self.username)
        if should_broadcast:
            self.broadcast_user_status(False)
        try:
            # shutdown() wakes this handler's thread if it is blocked in recv().
            self.client_socket.shutdown(socket.SHUT_RDWR)
        except OSError:
            pass
        try:
            self.client_socket.close()
        except OSError:
            pass
        logging.info(f"客户端断开: {self.client_address} ({self.username})")
class ChatServer:
    """TCP chat server that spawns one ClientHandler thread per connection.

    State shared with the handlers:
      * ``db`` -- the Database instance used for all persistence calls;
      * ``online_users`` -- {user_id: {handler, username, nickname, login_time}};
      * ``_users_lock`` -- the lock that guards ``online_users``.
    """

    def __init__(self, host='0.0.0.0', port=8888):
        self.host, self.port = host, port
        self.server_socket = None
        self.running = False
        self.client_threads = []
        self.db = Database()
        self.online_users = {}  # {user_id: {handler, username, nickname, login_time}}
        self._users_lock = threading.Lock()
        log_format = '%(asctime)s [%(levelname)s] %(message)s'
        log_handlers = [
            logging.FileHandler('server.log', encoding='utf-8'),
            logging.StreamHandler(),
        ]
        logging.basicConfig(level=logging.INFO, format=log_format, handlers=log_handlers)

    def start(self):
        """Bind, listen and serve until ``running`` is cleared; always ends with stop()."""
        listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.server_socket = listener
        listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        listener.bind((self.host, self.port))
        listener.listen(10)
        # A 1-second accept timeout lets the loop re-check self.running.
        listener.settimeout(1)
        self.running = True
        logging.info(f"服务器启动: {self.host}:{self.port}")
        try:
            while self.running:
                try:
                    conn, peer = listener.accept()
                    self._spawn_handler(conn, peer)
                except socket.timeout:
                    continue
                except Exception as exc:
                    if self.running:
                        logging.error(f"接受连接出错: {exc}")
        finally:
            self.stop()

    def _spawn_handler(self, conn, peer):
        """Start a ClientHandler for *conn* and prune threads that have exited."""
        worker = ClientHandler(conn, peer, self)
        worker.start()
        self.client_threads.append(worker)
        self.client_threads = [w for w in self.client_threads if w.is_alive()]

    def stop(self):
        """Clear ``running``, disconnect every tracked client and close the listener."""
        self.running = False
        for worker in self.client_threads:
            worker.disconnect()
        listener = self.server_socket
        if listener:
            listener.close()
        logging.info("服务器已停止")