You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

409 lines
15 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package server;
import com.google.gson.Gson;
import java.io.*;
import java.net.Socket;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
/**
 * Server-side handler for one WebSocket client connection.
 *
 * <p>{@link #run()} starts reading frames from the socket immediately, so the HTTP
 * upgrade handshake must already have been completed by whoever creates this handler.
 * It reassembles fragmented text messages and dispatches each complete JSON message
 * (LOGIN, LOGOUT, PRIVATE_MSG and FILE/IMAGE/VIDEO/AUDIO/VOICE) to the owning
 * {@link WebSocketServer}. {@link #sendMessage(String)} writes a text message back,
 * fragmenting payloads larger than 64 KB.
 *
 * <p>Every write to the socket goes through a method synchronized on this instance,
 * so frames written by different threads cannot interleave on the wire.
 */
public class WebSocketClient implements Runnable {

    /** Maximum payload length of a single incoming frame (20 MB). */
    private static final long MAX_FRAME_PAYLOAD = 20L * 1024 * 1024;

    /** Maximum size of one reassembled incoming message across all its fragments (20 MB). */
    private static final long MAX_MESSAGE_SIZE = 20L * 1024 * 1024;

    /** Outgoing messages larger than this many bytes are split into continuation frames. */
    private static final int MAX_OUTGOING_FRAME = 65536;

    /** RFC 6455 section 5.5: control frames carry at most 125 payload bytes. */
    private static final int MAX_CONTROL_PAYLOAD = 125;

    // Frame opcodes (RFC 6455 section 5.2).
    private static final int OP_CONTINUATION = 0x0;
    private static final int OP_TEXT = 0x1;
    private static final int OP_CLOSE = 0x8;
    private static final int OP_PING = 0x9;
    private static final int OP_PONG = 0xA;

    // Close status codes (RFC 6455 section 7.4.1).
    private static final int CLOSE_PROTOCOL_ERROR = 1002;
    private static final int CLOSE_MESSAGE_TOO_BIG = 1009;

    private final Socket socket;
    private final WebSocketServer server;
    private final Gson gson;
    private InputStream input;
    private OutputStream output;
    /** Name this connection is registered under in the server; null until a LOGIN succeeds. */
    private String username;
    private volatile boolean running;
    /** True while skipping the continuation frames of a fragmented non-text message. */
    private boolean discardingMessage;

    /**
     * Creates a handler for an already-accepted socket and tunes the socket for
     * large transfers. Failure to apply the socket options is logged, not fatal.
     *
     * @param socket the client connection
     * @param server the server this client registers with and forwards messages to
     */
    public WebSocketClient(Socket socket, WebSocketServer server) {
        this.socket = socket;
        this.server = server;
        this.running = true;
        this.gson = new Gson();
        try {
            // Larger socket buffers to support big file transfers.
            socket.setSendBufferSize(1024 * 1024); // 1MB
            socket.setReceiveBufferSize(1024 * 1024); // 1MB
            // Disable Nagle's algorithm for lower latency on small chat messages.
            socket.setTcpNoDelay(true);
        } catch (Exception e) {
            System.err.println("设置Socket参数失败: " + e.getMessage());
        }
    }

    /**
     * Reads frames until the connection ends, the client closes, or {@link #running}
     * is cleared, handing every complete text message to {@link #handleMessage(String)}.
     * Always unregisters the user and closes the socket on exit.
     */
    @Override
    public void run() {
        try {
            input = socket.getInputStream();
            output = socket.getOutputStream();
            System.out.println("WebSocket客户端处理器启动");

            // Buffer raw bytes rather than per-fragment decoded strings: a multi-byte
            // UTF-8 character may be split across two fragments, so decoding must wait
            // until the whole message has arrived.
            ByteArrayOutputStream messageBuffer = new ByteArrayOutputStream();
            while (running) {
                try {
                    FrameData frame = readWebSocketFrame();
                    if (frame == null) {
                        System.out.println("读取到 null 帧,连接可能已关闭");
                        break;
                    }
                    if ((long) messageBuffer.size() + frame.payload.length > MAX_MESSAGE_SIZE) {
                        System.err.println("消息太大: 超过 " + MAX_MESSAGE_SIZE + " 字节");
                        sendClose(closePayload(CLOSE_MESSAGE_TOO_BIG));
                        running = false;
                        break;
                    }
                    messageBuffer.write(frame.payload, 0, frame.payload.length);

                    // FIN marks the last fragment: decode and dispatch the whole message.
                    if (frame.fin) {
                        String completeMessage =
                                new String(messageBuffer.toByteArray(), StandardCharsets.UTF_8);
                        messageBuffer.reset();
                        if (!completeMessage.isEmpty()) {
                            System.out.println("收到完整消息,长度: " + completeMessage.length() + " 字符");
                            handleMessage(completeMessage);
                        }
                    }
                } catch (IOException e) {
                    if (running) {
                        System.err.println("读取消息时出错: " + e.getMessage());
                    }
                    break;
                }
            }
        } catch (IOException e) {
            if (running) {
                System.out.println("客户端 " + username + " 连接异常: " + e.getMessage());
            }
        } finally {
            cleanup();
        }
    }

    /** One unmasked data frame (text or continuation). */
    private static class FrameData {
        final boolean fin;
        final byte[] payload;

        FrameData(boolean fin, byte[] payload) {
            this.fin = fin;
            this.payload = payload;
        }
    }

    /**
     * Reads frames until a text frame or a continuation of a text message arrives.
     * Along the way it answers pings, ignores pongs, and skips binary or unknown
     * data frames.
     *
     * <p>Every frame is consumed completely, including header, extended length, masking
     * key and payload. This keeps the stream aligned on frame boundaries even for frames
     * that are ignored.
     *
     * @return the unmasked data frame, or {@code null} in any of these cases:
     *         the stream ended; the client sent a close frame; or a frame violated a
     *         size/protocol limit. In the last two cases {@link #running} is cleared.
     * @throws IOException if reading from, or replying on, the socket fails
     */
    private FrameData readWebSocketFrame() throws IOException {
        while (true) {
            int b0 = input.read();
            if (b0 == -1) {
                return null;
            }
            boolean fin = (b0 & 0x80) != 0;
            int opcode = b0 & 0x0F;
            System.out.println("读取帧: FIN=" + fin + ", opcode=" + opcode);

            int b1 = input.read();
            if (b1 == -1) {
                return null;
            }
            boolean masked = (b1 & 0x80) != 0;
            long payloadLength = b1 & 0x7F;
            if (payloadLength == 126 || payloadLength == 127) {
                // 126 -> 16-bit extended length, 127 -> 64-bit extended length (big-endian).
                byte[] ext = new byte[payloadLength == 126 ? 2 : 8];
                if (readFully(ext) != ext.length) {
                    return null;
                }
                payloadLength = 0;
                for (byte x : ext) {
                    payloadLength = (payloadLength << 8) | (x & 0xFF);
                }
            }
            System.out.println("Payload 长度: " + payloadLength + ", masked: " + masked);

            // Opcodes 0x8-0xF are control frames: never fragmented, at most 125 bytes.
            boolean isControl = (opcode & 0x08) != 0;
            if (isControl && (!fin || payloadLength > MAX_CONTROL_PAYLOAD)) {
                System.err.println("非法控制帧: opcode=" + opcode + ", 长度=" + payloadLength);
                sendClose(closePayload(CLOSE_PROTOCOL_ERROR));
                running = false;
                return null;
            }
            // A 64-bit length with its top bit set reads back as negative here.
            if (payloadLength < 0 || payloadLength > MAX_FRAME_PAYLOAD) {
                System.err.println("Payload太大: " + payloadLength + " 字节");
                sendClose(closePayload(CLOSE_MESSAGE_TOO_BIG));
                running = false;
                return null;
            }

            // The masking key is present whenever MASK is set, even for an empty payload.
            byte[] maskingKey = new byte[4];
            if (masked && readFully(maskingKey) != maskingKey.length) {
                System.err.println("读取 masking key 失败");
                return null;
            }

            byte[] payload = new byte[(int) payloadLength];
            int totalRead = readFully(payload);
            if (totalRead != payload.length) {
                System.err.println("读取payload时连接断开已读取: " + totalRead + "/" + payloadLength);
                return null;
            }
            if (payload.length == 0) {
                System.out.println("收到空 payload");
            } else {
                System.out.println("成功读取 payload: " + totalRead + " 字节");
            }
            if (masked) {
                for (int i = 0; i < payload.length; i++) {
                    payload[i] = (byte) (payload[i] ^ maskingKey[i & 3]);
                }
            }

            if (opcode == OP_CLOSE) {
                System.out.println("收到关闭帧");
                // Complete the closing handshake, echoing the client's status code if any.
                byte[] reply = payload.length >= 2 ? new byte[] {payload[0], payload[1]} : new byte[0];
                sendClose(reply);
                running = false;
                return null;
            }
            if (opcode == OP_PING) {
                System.out.println("收到 Ping 帧,发送 Pong");
                sendPong(payload);
                continue;
            }
            if (opcode == OP_PONG) {
                System.out.println("收到 Pong 帧");
                continue;
            }
            if (opcode == OP_TEXT) {
                discardingMessage = false;
                return new FrameData(fin, payload);
            }
            if (opcode == OP_CONTINUATION) {
                if (!discardingMessage) {
                    return new FrameData(fin, payload);
                }
                // Continuation of a non-text message that is being skipped.
                discardingMessage = !fin;
                continue;
            }
            // Binary or reserved opcode: the payload was already consumed above, so just
            // skip it, and also skip its continuation frames if it is fragmented.
            System.err.println("未知的 opcode: " + opcode + ",跳过");
            discardingMessage = !fin;
        }
    }

    /**
     * Reads until {@code buf} is full or the stream ends. A single
     * {@link InputStream#read(byte[], int, int)} call may return fewer bytes than requested.
     *
     * @return number of bytes read; less than {@code buf.length} only at end of stream
     */
    private int readFully(byte[] buf) throws IOException {
        int total = 0;
        while (total < buf.length) {
            int n = input.read(buf, total, buf.length - total);
            if (n == -1) {
                break;
            }
            total += n;
        }
        return total;
    }

    /** Answers a ping with a pong carrying the same application data (RFC 6455 section 5.5.3). */
    private void sendPong(byte[] pingPayload) throws IOException {
        sendControlFrame(OP_PONG, pingPayload);
    }

    /** Sends a close frame best-effort; failures are only logged because the connection is closing. */
    private void sendClose(byte[] payload) {
        try {
            sendControlFrame(OP_CLOSE, payload);
        } catch (IOException e) {
            System.err.println("发送关闭帧失败: " + e.getMessage());
        }
    }

    /** Encodes a close status code as the 2-byte big-endian close-frame payload. */
    private static byte[] closePayload(int statusCode) {
        return new byte[] {(byte) (statusCode >> 8), (byte) statusCode};
    }

    /**
     * Writes one unmasked, unfragmented control frame ({@code payload} at most 125 bytes).
     * It is synchronized with {@link #sendMessage(String)}, so a pong or close written
     * by the reader thread cannot land inside another thread's data frame.
     */
    private synchronized void sendControlFrame(int opcode, byte[] payload) throws IOException {
        byte[] frame = new byte[2 + payload.length];
        frame[0] = (byte) (0x80 | opcode);
        frame[1] = (byte) payload.length;
        System.arraycopy(payload, 0, frame, 2, payload.length);
        output.write(frame);
        output.flush();
    }

    /**
     * Parses one complete JSON message and dispatches it by its {@code type} field.
     * Malformed or unexpected messages are logged and ignored.
     */
    private void handleMessage(String jsonMessage) {
        try {
            if (jsonMessage == null || jsonMessage.trim().isEmpty()) {
                System.err.println("收到空消息,忽略");
                return;
            }
            @SuppressWarnings("unchecked")
            Map<String, Object> message = gson.fromJson(jsonMessage, Map.class);
            // Gson returns null for the literal JSON text "null".
            if (message == null) {
                System.err.println("收到空消息,忽略");
                return;
            }
            String type = (String) message.get("type");
            if (type == null) {
                System.err.println("消息缺少 type 字段");
                return;
            }
            System.out.println("处理消息类型: " + type);

            if ("LOGIN".equals(type)) {
                handleLogin((String) message.get("sender"));
            } else if ("LOGOUT".equals(type)) {
                running = false;
            } else if ("PRIVATE_MSG".equals(type)) {
                handlePrivateMessage(message);
            } else if ("FILE".equals(type) || "IMAGE".equals(type)
                    || "VIDEO".equals(type) || "AUDIO".equals(type) || "VOICE".equals(type)) {
                handleMediaMessage(type, message);
            }
        } catch (com.google.gson.JsonSyntaxException e) {
            System.err.println("JSON 解析失败: " + e.getMessage());
            System.err.println("消息内容前100字符: " +
                    (jsonMessage != null && jsonMessage.length() > 100 ?
                            jsonMessage.substring(0, 100) + "..." : jsonMessage));
        } catch (Exception e) {
            System.err.println("处理消息失败: " + e.getMessage());
            e.printStackTrace();
        }
    }

    /**
     * Registers this connection under {@code requested} and replies with LOGIN_SUCCESS or
     * LOGIN_FAILED. A failed login also stops the read loop.
     */
    private void handleLogin(String requested) {
        System.out.println("用户登录: " + requested);
        Map<String, Object> response = new HashMap<>();
        if (requested == null || requested.trim().isEmpty()) {
            response.put("type", "LOGIN_FAILED");
            response.put("content", "用户名不能为空");
            running = false;
        } else if (server.addClient(requested, this)) {
            // Store the name only after the server accepted it. cleanup() passes this
            // field to server.removeClient(name), so storing a rejected duplicate would
            // unregister the other connection that owns that name.
            username = requested;
            response.put("type", "LOGIN_SUCCESS");
            response.put("content", "登录成功");
        } else {
            response.put("type", "LOGIN_FAILED");
            response.put("content", "用户名已被使用,请更换用户名");
            running = false;
        }
        sendMessage(gson.toJson(response));
    }

    /** Forwards a PRIVATE_MSG from the logged-in user to its {@code receiver}. */
    private void handlePrivateMessage(Map<String, Object> message) {
        String sender = authenticatedSender(message);
        if (sender == null) {
            return;
        }
        String receiver = (String) message.get("receiver");
        String content = (String) message.get("content");
        System.out.println("转发消息: " + sender + " -> " + receiver);
        server.sendPrivateMessage(sender, receiver, content);
    }

    /** Stamps a file/multimedia message with the current time and forwards it via the server. */
    private void handleMediaMessage(String type, Map<String, Object> message) {
        String sender = authenticatedSender(message);
        if (sender == null) {
            return;
        }
        String receiver = (String) message.get("receiver");
        String fileName = (String) message.get("fileName");
        Object fileSize = message.get("fileSize");
        Object duration = message.get("duration");
        System.out.println("=== 转发多媒体消息 ===");
        System.out.println("类型: " + type);
        System.out.println("发送者: " + sender);
        System.out.println("接收者: " + receiver);
        System.out.println("文件名: " + fileName);
        System.out.println("文件大小: " + fileSize);
        if (duration != null) {
            System.out.println("时长: " + duration + " 秒");
        }
        message.put("timestamp", System.currentTimeMillis());
        server.forwardMediaMessage(message);
        System.out.println("消息已转发");
    }

    /**
     * Returns the logged-in name to use as the sender. If this connection has not logged
     * in yet, it logs and returns {@code null}. A client-supplied {@code sender} that
     * differs from the login name is overwritten in {@code message}, so one connection
     * cannot send messages in another user's name.
     */
    private String authenticatedSender(Map<String, Object> message) {
        if (username == null) {
            System.err.println("未登录的连接发送了消息,忽略");
            return null;
        }
        Object claimed = message.get("sender");
        if (!username.equals(claimed)) {
            System.err.println("sender 字段与登录用户不符: " + claimed + ",已替换为 " + username);
            message.put("sender", username);
        }
        return username;
    }

    /**
     * Sends {@code message} to the client as a UTF-8 text message.
     *
     * <p>Payloads up to {@value #MAX_OUTGOING_FRAME} bytes go out as a single frame. Larger
     * ones are split into a text frame followed by continuation frames. The method is
     * synchronized so concurrent callers cannot interleave frames. On I/O failure the
     * error is logged and {@link #running} is cleared, so the read loop exits at its next
     * iteration.
     */
    public synchronized void sendMessage(String message) {
        try {
            byte[] payload = message.getBytes(StandardCharsets.UTF_8);
            int payloadLength = payload.length;
            System.out.println("准备发送消息: " + payloadLength + " 字节");

            if (payloadLength <= MAX_OUTGOING_FRAME) {
                sendFrame(payload, 0, payloadLength, true, true);
                return;
            }
            System.out.println("消息过大,分片发送");
            int offset = 0;
            int fragmentCount = 0;
            while (offset < payloadLength) {
                int length = Math.min(MAX_OUTGOING_FRAME, payloadLength - offset);
                boolean isFinal = offset + length >= payloadLength;
                sendFrame(payload, offset, length, offset == 0, isFinal);
                offset += length;
                fragmentCount++;
                if (fragmentCount % 10 == 0) {
                    System.out.println("已发送: " + offset + "/" + payloadLength + " 字节");
                }
            }
            System.out.println("分片发送完成,共 " + fragmentCount + " 个片段");
        } catch (IOException e) {
            System.err.println("发送消息失败: " + e.getMessage());
            e.printStackTrace();
            running = false;
        }
    }

    /**
     * Writes one unmasked data frame carrying {@code payload[offset, offset + length)}.
     * Callers must hold the lock on {@code this}.
     *
     * @param first   true for the first frame of a message (text opcode), false for
     *                continuation frames
     * @param isFinal true to set the FIN bit on the last frame of a message
     */
    private void sendFrame(byte[] payload, int offset, int length, boolean first, boolean isFinal)
            throws IOException {
        ByteArrayOutputStream header = new ByteArrayOutputStream(10);
        header.write((isFinal ? 0x80 : 0x00) | (first ? OP_TEXT : OP_CONTINUATION));
        if (length <= 125) {
            header.write(length);
        } else if (length <= 65535) {
            header.write(126);
            header.write((length >> 8) & 0xFF);
            header.write(length & 0xFF);
        } else {
            header.write(127);
            // Widen to long before shifting. Java masks int shift distances to 5 bits, so
            // shifting an int by 56 really shifts by 24, which corrupted this 64-bit field
            // (e.g. for the 65536-byte fragments produced by sendMessage).
            long wide = length;
            for (int shift = 56; shift >= 0; shift -= 8) {
                header.write((int) ((wide >>> shift) & 0xFF));
            }
        }
        output.write(header.toByteArray());
        output.write(payload, offset, length);
        output.flush();
    }

    /**
     * Unregisters the logged-in user (if any) and closes the streams and the socket.
     * Each resource is closed independently, so one failure does not leak the others.
     */
    private void cleanup() {
        running = false;
        try {
            if (username != null) {
                server.removeClient(username);
            }
        } catch (Exception e) {
            System.err.println("清理资源失败: " + e.getMessage());
        } finally {
            closeQuietly(input);
            closeQuietly(output);
            closeQuietly(socket);
        }
    }

    /** Closes {@code resource} if non-null, logging rather than propagating any failure. */
    private static void closeQuietly(Closeable resource) {
        if (resource == null) {
            return;
        }
        try {
            resource.close();
        } catch (IOException e) {
            System.err.println("清理资源失败: " + e.getMessage());
        }
    }
}