WebSocketService.java 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316
  1. package com.zhentao.information.service;
  import com.alibaba.fastjson.JSON;
  import com.zhentao.groups.MongoDB.pojo.Message;
  import com.zhentao.groups.dto.GroupDto;
  import com.zhentao.groups.service.GroupsService;
  import com.zhentao.information.cache.ChannelCache;
  import com.zhentao.information.cache.GroupChannelCache;
  import com.zhentao.information.cache.GroupMemberCache;
  import com.zhentao.information.entity.ChatMessage;
  import com.zhentao.information.repository.ChatMessageRepository;
  import com.zhentao.tool.TokenUtils;
  import io.netty.channel.ChannelHandlerContext;
  import io.netty.channel.group.ChannelGroup;
  import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
  import lombok.extern.slf4j.Slf4j;
  import org.springframework.beans.factory.annotation.Autowired;
  import org.springframework.context.annotation.Lazy;
  import org.springframework.stereotype.Service;

  import javax.annotation.Resource;
  import java.nio.charset.StandardCharsets;
  import java.security.MessageDigest;
  import java.util.List;
  import java.util.Map;
  import java.util.concurrent.ConcurrentHashMap;
  import java.util.concurrent.Executors;
  import java.util.concurrent.RejectedExecutionException;
  import java.util.concurrent.ScheduledExecutorService;
  import java.util.concurrent.TimeUnit;
  26. /**
  27. * WebSocket服务类
  28. * 处理WebSocket连接、消息发送等业务逻辑
  29. */
  30. @Slf4j
  31. @Service
  32. public class WebSocketService {
  33. @Resource
  34. private ChannelCache channelCache;
  35. @Resource
  36. private GroupChannelCache groupChannelCache;
  37. @Resource
  38. private GroupMemberCache groupMemberCache;
  39. @Autowired
  40. @Lazy
  41. private GroupsService groupsService;
  42. @Autowired
  43. private ChatMessageRepository chatMessageRepository;
  44. // 存储用户token的Map
  45. private final Map<String, String> userTokenMap = new ConcurrentHashMap<>();
  46. private final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1);
  47. /**
  48. * 存储用户token
  49. * @param userId 用户ID
  50. * @param token 用户token
  51. */
  52. public void storeUserToken(String userId, String token) {
  53. userTokenMap.put(userId, token);
  54. log.info("用户 {} 的token已存储", userId);
  55. }
  56. /**
  57. * 获取用户token
  58. * @param userId 用户ID
  59. * @return 用户token
  60. */
  61. public String getUserToken(String userId) {
  62. return userTokenMap.get(userId);
  63. }
  64. /**
  65. * 处理用户登录
  66. * @param token 用户token
  67. * @param ctx Channel上下文
  68. * @return 用户ID,如果登录失败返回null
  69. */
  70. public String handleUserLogin(String token, ChannelHandlerContext ctx) {
  71. String userId = TokenUtils.getUserIdFromToken(token);
  72. if (userId != null) {
  73. // 验证token是否与存储的token匹配
  74. String storedToken = userTokenMap.get(userId);
  75. if (storedToken != null && storedToken.equals(token)) {
  76. // 将用户ID和Channel绑定
  77. channelCache.addCache(userId, ctx);
  78. log.info("用户 {} 连接成功", userId);
  79. // 发送连接成功消息
  80. Message response = new Message();
  81. response.setType("connect_success");
  82. response.setContent("连接成功");
  83. ctx.channel().writeAndFlush(new TextWebSocketFrame(JSON.toJSONString(response)));
  84. return userId;
  85. } else {
  86. log.error("用户 {} 的token验证失败", userId);
  87. ctx.close();
  88. }
  89. } else {
  90. log.error("无效的token");
  91. ctx.close();
  92. }
  93. return null;
  94. }
  95. /**
  96. * 发送消息给指定用户
  97. * @param userId 接收者用户ID
  98. * @param message 消息内容
  99. * @return 是否发送成功
  100. */
  101. public boolean sendMessageToUser(String userId, Message message) {
  102. if (message.getType() == null) {
  103. message.setType("text");
  104. }
  105. ChannelHandlerContext ctx = channelCache.getCache(userId);
  106. if (ctx != null && ctx.channel().isActive()) {
  107. try {
  108. String messageJson = JSON.toJSONString(message);
  109. log.info("发送消息给用户 {}: {}", userId, messageJson);
  110. ctx.channel().writeAndFlush(new TextWebSocketFrame(messageJson));
  111. return true;
  112. } catch (Exception e) {
  113. log.error("发送消息给用户 {} 失败", userId, e);
  114. retrySendMessage(userId, message);
  115. return false;
  116. }
  117. } else {
  118. log.info("用户 {} 不在线,消息将保存到数据库", userId);
  119. return false;
  120. }
  121. }
  122. private void retrySendMessage(String userId, Message message) {
  123. scheduler.schedule(() -> {
  124. ChannelHandlerContext ctx = channelCache.getCache(userId);
  125. if (ctx != null && ctx.channel().isActive()) {
  126. try {
  127. String messageJson = JSON.toJSONString(message);
  128. log.info("重试发送消息给用户 {}: {}", userId, messageJson);
  129. ctx.channel().writeAndFlush(new TextWebSocketFrame(messageJson));
  130. } catch (Exception e) {
  131. log.error("重试发送消息给用户 {} 失败", userId, e);
  132. }
  133. }
  134. }, 1, TimeUnit.SECONDS);
  135. }
  136. /**
  137. * 广播消息给所有在线用户
  138. * @param message 消息内容
  139. */
  140. public void broadcastMessage(Message message) {
  141. channelCache.getAllCache().forEach((userId, ctx) -> {
  142. if (ctx.channel().isActive()) {
  143. try {
  144. String messageJson = JSON.toJSONString(message);
  145. log.info("广播消息给用户 {}: {}", userId, messageJson);
  146. ctx.channel().writeAndFlush(new TextWebSocketFrame(messageJson));
  147. } catch (Exception e) {
  148. log.error("广播消息给用户 {} 失败", userId, e);
  149. }
  150. }
  151. });
  152. }
  153. /**
  154. * 检查用户是否在线
  155. * @param userId 用户ID
  156. * @return 是否在线
  157. */
  158. public boolean isUserOnline(String userId) {
  159. ChannelHandlerContext ctx = channelCache.getCache(userId);
  160. return ctx != null && ctx.channel().isActive();
  161. }
  162. /**
  163. * 处理群聊消息
  164. * @param message 群聊消息
  165. * @return 是否发送成功
  166. */
  167. public boolean handleGroupMessage(Message message) {
  168. Long groupId = message.getGroupId();
  169. if (groupId == null) {
  170. log.error("群聊消息缺少群ID");
  171. return false;
  172. }
  173. // 存储群聊消息到MongoDB
  174. ChatMessage chatMessage = new ChatMessage();
  175. chatMessage.setFromUserId(message.getFromUserId());
  176. chatMessage.setToUserId(String.valueOf(groupId));
  177. chatMessage.setContent(message.getContent());
  178. chatMessage.setType("group_chat");
  179. chatMessage.setTimestamp(System.currentTimeMillis());
  180. chatMessage.setIsRead(false);
  181. chatMessage.setChatId("group_" + groupId);
  182. chatMessage.setImageurl(message.getImageurl());
  183. chatMessageRepository.save(chatMessage);
  184. // 获取群成员
  185. List<GroupDto> groupList = groupsService.getList();
  186. List<Long> groupMembers = null;
  187. for (GroupDto group : groupList) {
  188. if (group.getGroupId().equals(groupId)) {
  189. groupMembers = group.getUid();
  190. break;
  191. }
  192. }
  193. if (groupMembers == null || groupMembers.isEmpty()) {
  194. log.error("群 {} 不存在或没有成员", groupId);
  195. return false;
  196. }
  197. boolean allSent = true;
  198. for (Long memberId : groupMembers) {
  199. String memberIdStr = String.valueOf(memberId);
  200. if (!memberIdStr.equals(message.getFromUserId())) {
  201. ChannelHandlerContext ctx = channelCache.getCache(memberIdStr);
  202. if (ctx != null && ctx.channel().isActive()) {
  203. try {
  204. String messageJson = JSON.toJSONString(message);
  205. log.info("发送群消息给用户 {}: {}", memberIdStr, messageJson);
  206. ctx.channel().writeAndFlush(new TextWebSocketFrame(messageJson));
  207. } catch (Exception e) {
  208. log.error("发送群消息给用户 {} 失败", memberId, e);
  209. allSent = false;
  210. }
  211. }
  212. }
  213. }
  214. log.info("群 {} 的消息已广播,群成员数:{}", groupId, groupMembers.size());
  215. return allSent;
  216. }
  217. /**
  218. * 用户登录时,将其加入所有群聊的ChannelGroup
  219. * @param userId 用户ID
  220. */
  221. public void joinAllGroups(String userId) {
  222. Long userIdLong = Long.valueOf(userId);
  223. Map<Long, List<Long>> allGroups = groupMemberCache.getAllGroupMembers();
  224. allGroups.forEach((groupId, members) -> {
  225. if (members.contains(userIdLong)) {
  226. addUserToGroup(groupId, userId);
  227. }
  228. });
  229. log.info("用户 {} 已加入所有群聊的ChannelGroup", userId);
  230. }
  231. /**
  232. * 将用户添加到群聊ChannelGroup
  233. * @param groupId 群ID
  234. * @param userId 用户ID
  235. * @return 是否添加成功
  236. */
  237. public boolean addUserToGroup(Long groupId, String userId) {
  238. // 验证用户是否在群中
  239. if (!groupMemberCache.isUserInGroup(groupId, Long.valueOf(userId))) {
  240. log.error("用户 {} 不在群 {} 中", userId, groupId);
  241. return false;
  242. }
  243. ChannelHandlerContext ctx = channelCache.getCache(userId);
  244. if (ctx == null || !ctx.channel().isActive()) {
  245. log.error("用户 {} 不在线", userId);
  246. return false;
  247. }
  248. ChannelGroup channelGroup = groupChannelCache.getGroup(groupId);
  249. if (channelGroup == null) {
  250. channelGroup = groupChannelCache.addGroup(groupId);
  251. }
  252. channelGroup.add(ctx.channel());
  253. log.info("用户 {} 已添加到群 {} 的ChannelGroup", userId, groupId);
  254. return true;
  255. }
  256. /**
  257. * 将用户从群聊ChannelGroup中移除
  258. * @param groupId 群ID
  259. * @param userId 用户ID
  260. * @return 是否移除成功
  261. */
  262. public boolean removeUserFromGroup(Long groupId, String userId) {
  263. ChannelHandlerContext ctx = channelCache.getCache(userId);
  264. if (ctx == null || !ctx.channel().isActive()) {
  265. return false;
  266. }
  267. ChannelGroup channelGroup = groupChannelCache.getGroup(groupId);
  268. if (channelGroup != null) {
  269. channelGroup.remove(ctx.channel());
  270. log.info("用户 {} 已从群 {} 的ChannelGroup移除", userId, groupId);
  271. return true;
  272. }
  273. return false;
  274. }
  275. public void removeUserConnection(ChannelHandlerContext ctx) {
  276. // 从ChannelCache中移除用户连接
  277. channelCache.removeCache(ctx);
  278. // 从所有群组中移除用户
  279. groupChannelCache.getAllGroups().forEach((groupId, channelGroup) -> {
  280. channelGroup.remove(ctx.channel());
  281. });
  282. }
  283. }