zhentao 2 weeks ago
parent
commit
66739ad9b7

+ 0 - 29
src/main/java/com/zhentao/config/NettyConfig.java

@@ -1,29 +0,0 @@
-package com.zhentao.config;
-
-import io.netty.bootstrap.ServerBootstrap;
-import io.netty.channel.ChannelOption;
-import io.netty.channel.nio.NioEventLoopGroup;
-import io.netty.channel.socket.nio.NioServerSocketChannel;
-import org.springframework.beans.factory.annotation.Value;
-import org.springframework.context.annotation.Bean;
-import org.springframework.context.annotation.Configuration;
-
-/**
- * Netty服务器配置类
- */
-@Configuration
-public class NettyConfig {
-
-    @Value("${netty.port:8888}")
-    private int port;
-
-    @Bean
-    public ServerBootstrap serverBootstrap() {
-        ServerBootstrap bootstrap = new ServerBootstrap();
-        bootstrap.group(new NioEventLoopGroup(), new NioEventLoopGroup())
-                .channel(NioServerSocketChannel.class)
-                .option(ChannelOption.SO_BACKLOG, 128)
-                .childOption(ChannelOption.SO_KEEPALIVE, true);
-        return bootstrap;
-    }
-}

+ 6 - 14
src/main/java/com/zhentao/information/controller/MessageController.java

@@ -1,10 +1,9 @@
 package com.zhentao.information.controller;
 package com.zhentao.information.controller;
 
 
 import com.alibaba.fastjson.JSON;
 import com.alibaba.fastjson.JSON;
-import com.zhentao.config.NullLogin;
+import com.zhentao.information.cache.ChannelCache;
 import com.zhentao.information.entity.ChatMessage;
 import com.zhentao.information.entity.ChatMessage;
 import com.zhentao.information.entity.Message;
 import com.zhentao.information.entity.Message;
-import com.zhentao.information.handler.WebSocketHandler;
 import com.zhentao.information.repository.ChatMessageRepository;
 import com.zhentao.information.repository.ChatMessageRepository;
 import com.zhentao.tool.TokenUtils;
 import com.zhentao.tool.TokenUtils;
 import io.netty.channel.ChannelHandlerContext;
 import io.netty.channel.ChannelHandlerContext;
@@ -14,7 +13,6 @@ import org.springframework.web.bind.annotation.*;
 
 
 import javax.annotation.Resource;
 import javax.annotation.Resource;
 import java.util.List;
 import java.util.List;
-import java.util.Map;
 
 
 /**
 /**
  * 消息控制器
  * 消息控制器
@@ -25,7 +23,7 @@ import java.util.Map;
 public class MessageController {
 public class MessageController {
 
 
     @Resource
     @Resource
-    private WebSocketHandler webSocketHandler;
+    private ChannelCache channelCache;
 
 
     @Resource
     @Resource
     private ChatMessageRepository chatMessageRepository;
     private ChatMessageRepository chatMessageRepository;
@@ -48,7 +46,7 @@ public class MessageController {
         chatMessage.setFromUserId(message.getFromUserId());
         chatMessage.setFromUserId(message.getFromUserId());
         chatMessage.setToUserId(message.getToUserId());
         chatMessage.setToUserId(message.getToUserId());
         chatMessage.setContent(message.getContent());
         chatMessage.setContent(message.getContent());
-        chatMessage.setType(message.getType());
+        chatMessage.setType(String.valueOf(message.getType()));
         chatMessage.setTimestamp(System.currentTimeMillis());
         chatMessage.setTimestamp(System.currentTimeMillis());
         chatMessage.setIsRead(false);
         chatMessage.setIsRead(false);
         chatMessage.setChatId(chatId);
         chatMessage.setChatId(chatId);
@@ -57,8 +55,7 @@ public class MessageController {
         chatMessageRepository.save(chatMessage);
         chatMessageRepository.save(chatMessage);
 
 
         // 获取接收者的Channel
         // 获取接收者的Channel
-        Map<String, ChannelHandlerContext> userChannelMap = webSocketHandler.getUserChannelMap();
-        ChannelHandlerContext toUserCtx = userChannelMap.get(message.getToUserId());
+        ChannelHandlerContext toUserCtx = channelCache.getCache(message.getToUserId());
 
 
         if (toUserCtx != null) {
         if (toUserCtx != null) {
             // 发送消息给接收者
             // 发送消息给接收者
@@ -73,13 +70,9 @@ public class MessageController {
      * 获取两个用户之间的聊天记录
      * 获取两个用户之间的聊天记录
      */
      */
     @GetMapping("/history")
     @GetMapping("/history")
-    public List<ChatMessage> getChatHistory(@RequestHeader("token") String token,@RequestParam String userId2) {
+    public List<ChatMessage> getChatHistory(@RequestHeader("token") String token, @RequestParam String userId2) {
         String userIdFromToken = TokenUtils.getUserIdFromToken(token);
         String userIdFromToken = TokenUtils.getUserIdFromToken(token);
-        System.err.println(userIdFromToken);
-        String chatId = null;
-        if (userIdFromToken != null) {
-            chatId = generateChatId(userIdFromToken, userId2);
-        }
+        String chatId = generateChatId(userIdFromToken, userId2);
         return chatMessageRepository.findByChatId(chatId);
         return chatMessageRepository.findByChatId(chatId);
     }
     }
 
 
@@ -95,7 +88,6 @@ public class MessageController {
      * 生成聊天ID
      * 生成聊天ID
      */
      */
     private String generateChatId(String userId1, String userId2) {
     private String generateChatId(String userId1, String userId2) {
-        // 确保两个用户之间的聊天ID唯一,且与顺序无关
         return userId1.compareTo(userId2) < 0 ?
         return userId1.compareTo(userId2) < 0 ?
                 userId1 + "_" + userId2 :
                 userId1 + "_" + userId2 :
                 userId2 + "_" + userId1;
                 userId2 + "_" + userId1;

+ 1 - 1
src/main/java/com/zhentao/information/entity/ChatMessage.java

@@ -23,7 +23,7 @@ public class ChatMessage {
 
 
     private String content;
     private String content;
 
 
-    private Integer type;
+    private String type;
 
 
     private Long timestamp;
     private Long timestamp;
 
 

+ 14 - 10
src/main/java/com/zhentao/information/entity/Message.java

@@ -3,32 +3,36 @@ package com.zhentao.information.entity;
 import lombok.Data;
 import lombok.Data;
 
 
 /**
 /**
- * 消息实体类
+ * WebSocket消息实体类
  */
  */
 @Data
 @Data
 public class Message {
 public class Message {
     /**
     /**
+     * 消息类型
+     * connect: 连接消息
+     * text: 文本消息
+     * image: 图片消息
+     * voice: 语音消息
+     */
+    private String type;
+    
+    /**
      * 发送者ID
      * 发送者ID
      */
      */
     private String fromUserId;
     private String fromUserId;
-
+    
     /**
     /**
      * 接收者ID
      * 接收者ID
      */
      */
     private String toUserId;
     private String toUserId;
-
+    
     /**
     /**
      * 消息内容
      * 消息内容
      */
      */
     private String content;
     private String content;
-
-    /**
-     * 消息类型
-     */
-    private Integer type;
-
+    
     /**
     /**
-     * 发送时间
+     * 消息时间戳
      */
      */
     private Long timestamp;
     private Long timestamp;
 }
 }

+ 89 - 26
src/main/java/com/zhentao/information/handler/WebSocketHandler.java

@@ -1,56 +1,119 @@
 package com.zhentao.information.handler;
 package com.zhentao.information.handler;
 
 
 import com.alibaba.fastjson.JSON;
 import com.alibaba.fastjson.JSON;
+import com.zhentao.information.entity.ChatMessage;
 import com.zhentao.information.entity.Message;
 import com.zhentao.information.entity.Message;
+import com.zhentao.information.repository.ChatMessageRepository;
+import com.zhentao.information.service.WebSocketService;
+import io.netty.channel.ChannelHandler;
 import io.netty.channel.ChannelHandlerContext;
 import io.netty.channel.ChannelHandlerContext;
 import io.netty.channel.SimpleChannelInboundHandler;
 import io.netty.channel.SimpleChannelInboundHandler;
 import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
 import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
 import lombok.extern.slf4j.Slf4j;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.stereotype.Component;
 import org.springframework.stereotype.Component;
 
 
-import java.util.Map;
-import java.util.concurrent.ConcurrentHashMap;
+import javax.annotation.Resource;
 
 
 /**
 /**
  * WebSocket消息处理器
  * WebSocket消息处理器
+ * 处理WebSocket连接、消息接收和发送
  */
  */
 @Slf4j
 @Slf4j
 @Component
 @Component
+@ChannelHandler.Sharable
 public class WebSocketHandler extends SimpleChannelInboundHandler<TextWebSocketFrame> {
 public class WebSocketHandler extends SimpleChannelInboundHandler<TextWebSocketFrame> {
 
 
-    // 用户ID和Channel的映射关系
-    private static final Map<String, ChannelHandlerContext> USER_CHANNEL_MAP = new ConcurrentHashMap<>();
+    @Resource
+    private ChatMessageRepository chatMessageRepository;
+
+    @Resource
+    private WebSocketService webSocketService;
 
 
     /**
     /**
-     * 获取用户Channel映射
-     * @return 用户Channel映射
+     * 处理接收到的WebSocket消息
      */
      */
-    public Map<String, ChannelHandlerContext> getUserChannelMap() {
-        return USER_CHANNEL_MAP;
+    @Override
+    protected void channelRead0(ChannelHandlerContext ctx, TextWebSocketFrame msg) throws Exception {
+        String text = msg.text();
+        log.info("收到消息:{}", text);
+        try {
+            Message message = JSON.parseObject(text, Message.class);
+
+            // 如果是连接消息,处理token
+            if ("connect".equals(message.getType())) {
+                webSocketService.handleUserLogin(message.getContent(), ctx);
+                return;
+            }
+
+            // 处理普通消息
+            handleMessage(message);
+
+        } catch (Exception e) {
+            log.error("处理消息失败", e);
+        }
     }
     }
 
 
-    @Override
-    protected void channelRead0(ChannelHandlerContext ctx, TextWebSocketFrame msg) {
-        String message = msg.text();
-        Message messageObj = JSON.parseObject(message, Message.class);
-
-        // 存储用户连接
-        USER_CHANNEL_MAP.put(messageObj.getFromUserId(), ctx);
-
-        // 获取接收者的Channel
-        ChannelHandlerContext toUserCtx = USER_CHANNEL_MAP.get(messageObj.getToUserId());
-        if (toUserCtx != null) {
-            // 发送消息给接收者
-            toUserCtx.channel().writeAndFlush(new TextWebSocketFrame(JSON.toJSONString(messageObj)));
-            log.info("消息已发送给用户: {}, 内容: {}", messageObj.getToUserId(), messageObj.getContent());
+    /**
+     * 处理普通消息
+     */
+    private void handleMessage(Message message) {
+        // 生成聊天ID
+        String chatId = generateChatId(message.getFromUserId(), message.getToUserId());
+
+        // 创建MongoDB消息对象
+        ChatMessage chatMessage = new ChatMessage();
+        chatMessage.setFromUserId(message.getFromUserId());
+        chatMessage.setToUserId(message.getToUserId());
+        chatMessage.setContent(message.getContent());
+        chatMessage.setType(String.valueOf(message.getType()));
+        chatMessage.setTimestamp(System.currentTimeMillis());
+        chatMessage.setIsRead(false);
+        chatMessage.setChatId(chatId);
+
+        // 保存消息到MongoDB
+        chatMessageRepository.save(chatMessage);
+
+        // 发送消息给接收者
+        boolean sent = webSocketService.sendMessageToUser(message.getToUserId(), message);
+        System.err.println("判断对方用户是否在线"+sent);
+        if (sent) {
+            log.info("消息已发送给用户: {}, 内容: {}", message.getToUserId(), message.getContent());
         } else {
         } else {
-            log.info("用户 {} 不在线", messageObj.getToUserId());
+            log.info("用户 {} 不在线,消息已保存到MongoDB", message.getToUserId());
         }
         }
     }
     }
 
 
+    /**
+     * 当新的WebSocket连接建立时调用
+     */
+    @Override
+    public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
+        log.info("新的连接:{}", ctx.channel().id().asLongText());
+    }
+
+    /**
+     * 当WebSocket连接断开时调用
+     */
+    @Override
+    public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
+        log.info("连接断开:{}", ctx.channel().id().asLongText());
+    }
+
+    /**
+     * 处理异常情况
+     */
     @Override
     @Override
-    public void handlerRemoved(ChannelHandlerContext ctx) {
-        // 用户断开连接时,移除映射关系
-        USER_CHANNEL_MAP.entrySet().removeIf(entry -> entry.getValue().equals(ctx));
+    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
+        log.error("WebSocket异常", cause);
+        ctx.close();
+    }
+
+    /**
+     * 生成聊天ID
+     */
+    private String generateChatId(String userId1, String userId2) {
+        return userId1.compareTo(userId2) < 0 ?
+                userId1 + "_" + userId2 :
+                userId2 + "_" + userId1;
     }
     }
 }
 }

+ 2 - 2
src/main/java/com/zhentao/information/repository/ChatMessageRepository.java

@@ -2,19 +2,19 @@ package com.zhentao.information.repository;
 
 
 import com.zhentao.information.entity.ChatMessage;
 import com.zhentao.information.entity.ChatMessage;
 import org.springframework.data.mongodb.repository.MongoRepository;
 import org.springframework.data.mongodb.repository.MongoRepository;
-import org.springframework.data.mongodb.repository.Query;
+import org.springframework.stereotype.Repository;
 
 
 import java.util.List;
 import java.util.List;
 
 
 /**
 /**
  * 聊天消息仓库
  * 聊天消息仓库
  */
  */
+@Repository
 public interface ChatMessageRepository extends MongoRepository<ChatMessage, String> {
 public interface ChatMessageRepository extends MongoRepository<ChatMessage, String> {
 
 
     /**
     /**
      * 查询两个用户之间的聊天记录
      * 查询两个用户之间的聊天记录
      */
      */
-    @Query("{'chatId': ?0}")
     List<ChatMessage> findByChatId(String chatId);
     List<ChatMessage> findByChatId(String chatId);
 
 
     /**
     /**

+ 4 - 3
src/main/java/com/zhentao/user/controller/UserController.java

@@ -4,7 +4,7 @@ import com.aliyun.oss.OSS;
 import com.aliyun.oss.OSSClientBuilder;
 import com.aliyun.oss.OSSClientBuilder;
 import com.aliyun.oss.model.PutObjectRequest;
 import com.aliyun.oss.model.PutObjectRequest;
 import com.zhentao.config.NullLogin;
 import com.zhentao.config.NullLogin;
-
+import com.zhentao.information.service.WebSocketService;
 import com.zhentao.osspicture.OssUtil;
 import com.zhentao.osspicture.OssUtil;
 import com.zhentao.tool.TokenUtils;
 import com.zhentao.tool.TokenUtils;
 import com.zhentao.user.domain.UserLogin;
 import com.zhentao.user.domain.UserLogin;
@@ -38,6 +38,7 @@ public class UserController {
     @Autowired
     @Autowired
     public OssUtil ossUtil;
     public OssUtil ossUtil;
 
 
+
     //注册
     //注册
     @PostMapping("/register")
     @PostMapping("/register")
     @NullLogin
     @NullLogin
@@ -64,8 +65,8 @@ public class UserController {
     @PostMapping("/UserPassLogin")
     @PostMapping("/UserPassLogin")
     @NullLogin
     @NullLogin
     public Result UserPassLogin(@RequestBody @Valid UserPassDto userPassDto) {
     public Result UserPassLogin(@RequestBody @Valid UserPassDto userPassDto) {
-
-        return userLoginService.UserPassLogin(userPassDto);
+        Result result = userLoginService.UserPassLogin(userPassDto);
+        return result;
     }
     }
 
 
 //    忘记密码
 //    忘记密码

+ 8 - 2
src/main/java/com/zhentao/user/service/impl/UserLoginServiceImpl.java

@@ -6,6 +6,7 @@ import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
 import com.zhentao.enums.ApiServerException;
 import com.zhentao.enums.ApiServerException;
 import com.zhentao.exception.AsynException;
 import com.zhentao.exception.AsynException;
 
 
+import com.zhentao.information.service.WebSocketService;
 import com.zhentao.tool.TokenUtils;
 import com.zhentao.tool.TokenUtils;
 import com.zhentao.user.domain.UserLogin;
 import com.zhentao.user.domain.UserLogin;
 import com.zhentao.user.dto.*;
 import com.zhentao.user.dto.*;
@@ -37,6 +38,8 @@ public class UserLoginServiceImpl extends ServiceImpl<UserLoginMapper, UserLogin
     private RedissonClient redissonClient;
     private RedissonClient redissonClient;
     @Autowired
     @Autowired
     private StringRedisTemplate stringRedisTemplate;
     private StringRedisTemplate stringRedisTemplate;
+    @Autowired
+    public WebSocketService webSocketService;
 
 
     //注册
     //注册
     @Override
     @Override
@@ -198,7 +201,6 @@ public class UserLoginServiceImpl extends ServiceImpl<UserLoginMapper, UserLogin
                 QueryWrapper<UserLogin> queryWrapper = new QueryWrapper<>();
                 QueryWrapper<UserLogin> queryWrapper = new QueryWrapper<>();
                 queryWrapper.eq("user_username",userPassDto.getUsername());
                 queryWrapper.eq("user_username",userPassDto.getUsername());
                 UserLogin one = this.getOne(queryWrapper);
                 UserLogin one = this.getOne(queryWrapper);
-                System.err.println(one.toString());
                 // 如果用户不存在,抛出异常
                 // 如果用户不存在,抛出异常
                 if (one==null){
                 if (one==null){
                     throw new AsynException(ApiServerException.NULL_USERNAME);
                     throw new AsynException(ApiServerException.NULL_USERNAME);
@@ -213,9 +215,13 @@ public class UserLoginServiceImpl extends ServiceImpl<UserLoginMapper, UserLogin
                 // 生成JWT令牌
                 // 生成JWT令牌
                 String jwtToken = TokenUtils.generateToken(one.getId()+"");
                 String jwtToken = TokenUtils.generateToken(one.getId()+"");
                 stringRedisTemplate.opsForValue().set(one.getId().toString(),jwtToken);
                 stringRedisTemplate.opsForValue().set(one.getId().toString(),jwtToken);
-                System.err.println(jwtToken);
                 System.err.println(stringRedisTemplate.opsForValue().get(one.getId().toString()));
                 System.err.println(stringRedisTemplate.opsForValue().get(one.getId().toString()));
                 // 返回登录成功结果和JWT令牌
                 // 返回登录成功结果和JWT令牌
+
+                // 将用户ID和token存储到WebSocketService中
+                webSocketService.storeUserToken(one.getId()+"", jwtToken);
+
+
                 return Result.OK("登录成功",jwtToken);
                 return Result.OK("登录成功",jwtToken);
             }else {
             }else {
                 // 如果获取锁超时,返回错误信息
                 // 如果获取锁超时,返回错误信息

+ 4 - 2
src/main/resources/application.yml

@@ -2,6 +2,8 @@ server:
   port: 8081
   port: 8081
 netty:
 netty:
   port: 8888
   port: 8888
+  websocket:
+    port: 9099
 spring:
 spring:
   application:
   application:
     name: im-server
     name: im-server
@@ -26,9 +28,9 @@ spring:
   # MongoDB配置
   # MongoDB配置
   data:
   data:
     mongodb:
     mongodb:
-      host: 47.110.46.22
+      host: localhost
       port: 27017
       port: 27017
-      database: im_message_db
+      database: im_db
   redis:
   redis:
     host: 47.110.46.22
     host: 47.110.46.22
     port: 6379
     port: 6379