zhentao 2 weeks ago
parent
commit
7472b925e8

+ 70 - 0
src/main/java/com/zhentao/information/cache/ChannelCache.java

@@ -0,0 +1,70 @@
+package com.zhentao.information.cache;
+
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelHandlerContext;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.stereotype.Component;
+
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+/**
+ * Channel缓存管理类
+ * 用于管理用户Channel
+ */
+@Slf4j
+@Component
+public class ChannelCache {
+    
+    /**
+     * 用户ID和Channel的映射关系
+     * key: 用户ID
+     * value: Channel上下文
+     */
+    private static final Map<String, ChannelHandlerContext> USER_CHANNEL_MAP = new ConcurrentHashMap<>();
+
+    /**
+     * 添加用户Channel映射
+     * @param userId 用户ID
+     * @param ctx Channel上下文
+     */
+    public void addCache(String userId, ChannelHandlerContext ctx) {
+        USER_CHANNEL_MAP.put(userId, ctx);
+        log.info("用户 {} 的Channel已添加到缓存", userId);
+    }
+
+    /**
+     * 获取用户的Channel
+     * @param userId 用户ID
+     * @return Channel上下文
+     */
+    public ChannelHandlerContext getCache(String userId) {
+        return USER_CHANNEL_MAP.get(userId);
+    }
+
+    /**
+     * 移除用户Channel映射
+     * @param userId 用户ID
+     */
+    public void removeCache(String userId) {
+        USER_CHANNEL_MAP.remove(userId);
+        log.info("用户 {} 的Channel已从缓存移除", userId);
+    }
+
+    /**
+     * 获取所有用户Channel映射
+     * @return 用户Channel映射Map
+     */
+    public Map<String, ChannelHandlerContext> getAllCache() {
+        return USER_CHANNEL_MAP;
+    }
+
+    /**
+     * 判断用户是否在线
+     * @param userId 用户ID
+     * @return 是否在线
+     */
+    public boolean isOnline(String userId) {
+        return USER_CHANNEL_MAP.containsKey(userId);
+    }
+} 

+ 99 - 0
src/main/java/com/zhentao/information/config/NettyConfig.java

@@ -0,0 +1,99 @@
+package com.zhentao.information.config;
+
+import com.zhentao.information.handler.WebSocketHandler;
+import io.netty.bootstrap.ServerBootstrap;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.ChannelOption;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.nio.NioEventLoopGroup;
+import io.netty.channel.socket.SocketChannel;
+import io.netty.channel.socket.nio.NioServerSocketChannel;
+import io.netty.handler.codec.http.HttpObjectAggregator;
+import io.netty.handler.codec.http.HttpServerCodec;
+import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
+import io.netty.handler.stream.ChunkedWriteHandler;
+import io.netty.handler.timeout.IdleStateHandler;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.beans.factory.annotation.Value;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+
+import javax.annotation.Resource;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * Netty服务器配置类
+ * 配置WebSocket服务器的启动参数和处理器链
+ */
+@Slf4j
+@Configuration
+public class NettyConfig {
+
+    /**
+     * WebSocket服务器端口
+     */
+    @Value("${netty.websocket.port}")
+    private int port;
+
+    /**
+     * WebSocket消息处理器
+     */
+    @Resource
+    private WebSocketHandler webSocketHandler;
+
+    /**
+     * 配置并启动Netty服务器
+     * @return ServerBootstrap实例
+     */
+    @Bean
+    public ServerBootstrap serverBootstrap() {
+        // 创建主从线程组
+        // bossGroup用于接收客户端连接
+        EventLoopGroup bossGroup = new NioEventLoopGroup(1);
+        // workerGroup用于处理客户端数据
+        EventLoopGroup workerGroup = new NioEventLoopGroup();
+        
+        // 创建服务器启动对象
+        ServerBootstrap bootstrap = new ServerBootstrap();
+        bootstrap.group(bossGroup, workerGroup)
+                // 设置服务器通道实现
+                .channel(NioServerSocketChannel.class)
+                // 设置线程队列等待连接个数
+                .option(ChannelOption.SO_BACKLOG, 128)
+                // 设置保持活动连接状态
+                .childOption(ChannelOption.SO_KEEPALIVE, true)
+                // 禁用Nagle算法,减少延迟
+                .childOption(ChannelOption.TCP_NODELAY, true)
+                // 设置处理器
+                .childHandler(new ChannelInitializer<SocketChannel>() {
+                    @Override
+                    protected void initChannel(SocketChannel ch) {
+                        // 获取管道
+                        ch.pipeline()
+                                // HTTP编解码器
+                                .addLast(new HttpServerCodec())
+                                // 支持大数据流
+                                .addLast(new ChunkedWriteHandler())
+                                // HTTP消息聚合器
+                                .addLast(new HttpObjectAggregator(65536))
+                                // 心跳检测,60秒没有收到消息就触发
+                                .addLast(new IdleStateHandler(60, 0, 0, TimeUnit.SECONDS))
+                                // WebSocket协议处理器
+                                .addLast(new WebSocketServerProtocolHandler("/ws", null, true))
+                                // 自定义消息处理器
+                                .addLast(webSocketHandler);
+                    }
+                });
+        
+        try {
+            // 绑定端口并启动服务器
+            bootstrap.bind(port).sync();
+            log.info("Netty WebSocket服务器启动成功,端口:{}", port);
+        } catch (InterruptedException e) {
+            log.error("Netty WebSocket服务器启动失败", e);
+            Thread.currentThread().interrupt();
+        }
+        
+        return bootstrap;
+    }
+} 

+ 40 - 0
src/main/java/com/zhentao/information/handler/HeartbeatHandler.java

@@ -0,0 +1,40 @@
+package com.zhentao.information.handler;
+
+import io.netty.channel.ChannelHandler;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.handler.timeout.IdleState;
+import io.netty.handler.timeout.IdleStateEvent;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.stereotype.Component;
+
+/**
+ * 心跳处理器
+ * 处理客户端的心跳检测
+ */
+@Slf4j
+@Component
+@ChannelHandler.Sharable
+public class HeartbeatHandler extends ChannelInboundHandlerAdapter {
+
+    /**
+     * 处理用户事件
+     * 当触发IdleStateEvent时调用
+     * @param ctx Channel上下文
+     * @param evt 事件对象
+     */
+    @Override
+    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
+        if (evt instanceof IdleStateEvent) {
+            IdleStateEvent event = (IdleStateEvent) evt;
+            
+            // 如果是读空闲事件
+            if (event.state() == IdleState.READER_IDLE) {
+                log.info("读空闲,关闭连接:{}", ctx.channel().id().asLongText());
+                ctx.close();
+            }
+        } else {
+            super.userEventTriggered(ctx, evt);
+        }
+    }
+} 

+ 112 - 0
src/main/java/com/zhentao/information/service/WebSocketService.java

@@ -0,0 +1,112 @@
+package com.zhentao.information.service;
+
+import com.alibaba.fastjson.JSON;
+import com.zhentao.information.cache.ChannelCache;
+import com.zhentao.information.entity.Message;
+import com.zhentao.tool.TokenUtils;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.stereotype.Service;
+
+import javax.annotation.Resource;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+/**
+ * WebSocket服务类
+ * 处理WebSocket连接、消息发送等业务逻辑
+ */
+@Slf4j
+@Service
+public class WebSocketService {
+
+    @Resource
+    private ChannelCache channelCache;
+
+    // 存储用户token的Map
+    private final Map<String, String> userTokenMap = new ConcurrentHashMap<>();
+
+    /**
+     * 存储用户token
+     * @param userId 用户ID
+     * @param token 用户token
+     */
+    public void storeUserToken(String userId, String token) {
+        userTokenMap.put(userId, token);
+        log.info("用户 {} 的token已存储", userId);
+    }
+
+    /**
+     * 获取用户token
+     * @param userId 用户ID
+     * @return 用户token
+     */
+    public String getUserToken(String userId) {
+        return userTokenMap.get(userId);
+    }
+
+    /**
+     * 处理用户登录
+     * @param token 用户token
+     * @param ctx Channel上下文
+     */
+    public void handleUserLogin(String token, ChannelHandlerContext ctx) {
+        String userId = TokenUtils.getUserIdFromToken(token);
+        if (userId != null) {
+            // 验证token是否与存储的token匹配
+            String storedToken = userTokenMap.get(userId);
+            if (storedToken != null && storedToken.equals(token)) {
+                // 将用户ID和Channel绑定
+                channelCache.addCache(userId, ctx);
+                log.info("用户 {} 连接成功", userId);
+                
+                // 发送连接成功消息
+                Message response = new Message();
+                response.setType("connect_success");
+                response.setContent("连接成功");
+                ctx.channel().writeAndFlush(new TextWebSocketFrame(JSON.toJSONString(response)));
+            } else {
+                log.error("用户 {} 的token验证失败", userId);
+                ctx.close();
+            }
+        } else {
+            log.error("无效的token");
+            ctx.close();
+        }
+    }
+
+    /**
+     * 发送消息给指定用户
+     * @param userId 接收者用户ID
+     * @param message 消息内容
+     * @return 是否发送成功
+     */
+    public boolean sendMessageToUser(String userId, Message message) {
+        ChannelHandlerContext ctx = channelCache.getCache(userId);
+        if (ctx != null) {
+            ctx.channel().writeAndFlush(new TextWebSocketFrame(JSON.toJSONString(message)));
+            return true;
+        }
+        return false;
+    }
+
+    /**
+     * 广播消息给所有在线用户
+     * @param message 消息内容
+     */
+    public void broadcastMessage(Message message) {
+        channelCache.getAllCache().forEach((userId, ctx) -> {
+            ctx.channel().writeAndFlush(new TextWebSocketFrame(JSON.toJSONString(message)));
+        });
+    }
+
+    /**
+     * 检查用户是否在线
+     * @param userId 用户ID
+     * @return 是否在线
+     */
+    public boolean isUserOnline(String userId) {
+        return channelCache.getCache(userId) != null;
+    }
+}