package com.zhentao.information.service; import com.alibaba.fastjson.JSON; import com.zhentao.groups.MongoDB.pojo.Message; import com.zhentao.information.cache.ChannelCache; import com.zhentao.information.cache.GroupChannelCache; import com.zhentao.information.cache.GroupMemberCache; import com.zhentao.tool.TokenUtils; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.group.ChannelGroup; import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Lazy; import org.springframework.stereotype.Service; import com.zhentao.groups.service.GroupsService; import com.zhentao.groups.dto.GroupDto; import com.zhentao.information.entity.ChatMessage; import com.zhentao.information.repository.ChatMessageRepository; import javax.annotation.Resource; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; /** * WebSocket服务类 * 处理WebSocket连接、消息发送等业务逻辑 */ @Slf4j @Service public class WebSocketService { @Resource private ChannelCache channelCache; @Resource private GroupChannelCache groupChannelCache; @Resource private GroupMemberCache groupMemberCache; @Autowired @Lazy private GroupsService groupsService; @Autowired private ChatMessageRepository chatMessageRepository; // 存储用户token的Map private final Map userTokenMap = new ConcurrentHashMap<>(); private final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1); /** * 存储用户token * @param userId 用户ID * @param token 用户token */ public void storeUserToken(String userId, String token) { userTokenMap.put(userId, token); log.info("用户 {} 的token已存储", userId); } /** * 获取用户token * @param userId 用户ID * @return 用户token */ public String getUserToken(String userId) { return userTokenMap.get(userId); } /** * 处理用户登录 * @param token 用户token * @param ctx Channel上下文 * @return 用户ID,如果登录失败返回null */ public String handleUserLogin(String token, ChannelHandlerContext ctx) { String userId = TokenUtils.getUserIdFromToken(token); if (userId != null) { // 验证token是否与存储的token匹配 String storedToken = userTokenMap.get(userId); if (storedToken != null && storedToken.equals(token)) { // 将用户ID和Channel绑定 channelCache.addCache(userId, ctx); log.info("用户 {} 连接成功", userId); // 发送连接成功消息 Message response = new Message(); response.setType("connect_success"); response.setContent("连接成功"); ctx.channel().writeAndFlush(new TextWebSocketFrame(JSON.toJSONString(response))); return userId; } else { log.error("用户 {} 的token验证失败", userId); ctx.close(); } } else { log.error("无效的token"); ctx.close(); } return null; } /** * 发送消息给指定用户 * @param userId 接收者用户ID * @param message 消息内容 * @return 是否发送成功 */ public boolean sendMessageToUser(String userId, Message message) { if (message.getType() == null) { message.setType("text"); } ChannelHandlerContext ctx = channelCache.getCache(userId); if (ctx != null && ctx.channel().isActive()) { try { String messageJson = JSON.toJSONString(message); log.info("发送消息给用户 {}: {}", userId, messageJson); ctx.channel().writeAndFlush(new TextWebSocketFrame(messageJson)); return true; } catch (Exception e) { log.error("发送消息给用户 {} 失败", userId, e); retrySendMessage(userId, message); return false; } } else { log.info("用户 {} 不在线,消息将保存到数据库", userId); return false; } } private void retrySendMessage(String userId, Message message) { scheduler.schedule(() -> { ChannelHandlerContext ctx = channelCache.getCache(userId); if (ctx != null && ctx.channel().isActive()) { try { String messageJson = JSON.toJSONString(message); log.info("重试发送消息给用户 {}: {}", userId, messageJson); ctx.channel().writeAndFlush(new TextWebSocketFrame(messageJson)); } catch (Exception e) { log.error("重试发送消息给用户 {} 失败", userId, e); } } }, 1, TimeUnit.SECONDS); } /** * 广播消息给所有在线用户 * @param message 消息内容 */ public void broadcastMessage(Message message) { channelCache.getAllCache().forEach((userId, ctx) -> { if (ctx.channel().isActive()) { try { String messageJson = JSON.toJSONString(message); log.info("广播消息给用户 {}: {}", userId, messageJson); ctx.channel().writeAndFlush(new TextWebSocketFrame(messageJson)); } catch (Exception e) { log.error("广播消息给用户 {} 失败", userId, e); } } }); } /** * 检查用户是否在线 * @param userId 用户ID * @return 是否在线 */ public boolean isUserOnline(String userId) { ChannelHandlerContext ctx = channelCache.getCache(userId); return ctx != null && ctx.channel().isActive(); } /** * 处理群聊消息 * @param message 群聊消息 * @return 是否发送成功 */ public boolean handleGroupMessage(Message message) { Long groupId = message.getGroupId(); if (groupId == null) { log.error("群聊消息缺少群ID"); return false; } // 存储群聊消息到MongoDB ChatMessage chatMessage = new ChatMessage(); chatMessage.setFromUserId(message.getFromUserId()); chatMessage.setToUserId(String.valueOf(groupId)); chatMessage.setContent(message.getContent()); chatMessage.setType("group_chat"); chatMessage.setTimestamp(System.currentTimeMillis()); chatMessage.setIsRead(false); chatMessage.setChatId("group_" + groupId); chatMessage.setImageurl(message.getImageurl()); chatMessageRepository.save(chatMessage); // 获取群成员 List groupList = groupsService.getList(); List groupMembers = null; for (GroupDto group : groupList) { if (group.getGroupId().equals(groupId)) { groupMembers = group.getUid(); break; } } if (groupMembers == null || groupMembers.isEmpty()) { log.error("群 {} 不存在或没有成员", groupId); return false; } boolean allSent = true; for (Long memberId : groupMembers) { String memberIdStr = String.valueOf(memberId); if (!memberIdStr.equals(message.getFromUserId())) { ChannelHandlerContext ctx = channelCache.getCache(memberIdStr); if (ctx != null && ctx.channel().isActive()) { try { String messageJson = JSON.toJSONString(message); log.info("发送群消息给用户 {}: {}", memberIdStr, messageJson); ctx.channel().writeAndFlush(new TextWebSocketFrame(messageJson)); } catch (Exception e) { log.error("发送群消息给用户 {} 失败", memberId, e); allSent = false; } } } } log.info("群 {} 的消息已广播,群成员数:{}", groupId, groupMembers.size()); return allSent; } /** * 用户登录时,将其加入所有群聊的ChannelGroup * @param userId 用户ID */ public void joinAllGroups(String userId) { Long userIdLong = Long.valueOf(userId); Map> allGroups = groupMemberCache.getAllGroupMembers(); allGroups.forEach((groupId, members) -> { if (members.contains(userIdLong)) { addUserToGroup(groupId, userId); } }); log.info("用户 {} 已加入所有群聊的ChannelGroup", userId); } /** * 将用户添加到群聊ChannelGroup * @param groupId 群ID * @param userId 用户ID * @return 是否添加成功 */ public boolean addUserToGroup(Long groupId, String userId) { // 验证用户是否在群中 if (!groupMemberCache.isUserInGroup(groupId, Long.valueOf(userId))) { log.error("用户 {} 不在群 {} 中", userId, groupId); return false; } ChannelHandlerContext ctx = channelCache.getCache(userId); if (ctx == null || !ctx.channel().isActive()) { log.error("用户 {} 不在线", userId); return false; } ChannelGroup channelGroup = groupChannelCache.getGroup(groupId); if (channelGroup == null) { channelGroup = groupChannelCache.addGroup(groupId); } channelGroup.add(ctx.channel()); log.info("用户 {} 已添加到群 {} 的ChannelGroup", userId, groupId); return true; } /** * 将用户从群聊ChannelGroup中移除 * @param groupId 群ID * @param userId 用户ID * @return 是否移除成功 */ public boolean removeUserFromGroup(Long groupId, String userId) { ChannelHandlerContext ctx = channelCache.getCache(userId); if (ctx == null || !ctx.channel().isActive()) { return false; } ChannelGroup channelGroup = groupChannelCache.getGroup(groupId); if (channelGroup != null) { channelGroup.remove(ctx.channel()); log.info("用户 {} 已从群 {} 的ChannelGroup移除", userId, groupId); return true; } return false; } public void removeUserConnection(ChannelHandlerContext ctx) { // 从ChannelCache中移除用户连接 channelCache.removeCache(ctx); // 从所有群组中移除用户 groupChannelCache.getAllGroups().forEach((groupId, channelGroup) -> { channelGroup.remove(ctx.channel()); }); } }