zhentao 1 week ago
parent
commit
ce6b84d613

+ 16 - 0
pom.xml

@@ -15,6 +15,22 @@
     </properties>
     <dependencies>
         <dependency>
+            <groupId>org.springframework.boot</groupId>
+            <artifactId>spring-boot-starter-websocket</artifactId>
+        </dependency>
+        <dependency>
+            <groupId>com.alibaba</groupId>
+            <artifactId>dashscope-sdk-java</artifactId>
+            <version>2.18.2</version>
+            <exclusions>
+                <exclusion>
+                    <groupId>org.slf4j</groupId>
+                    <artifactId>slf4j-simple</artifactId>
+                </exclusion>
+            </exclusions>
+        </dependency>
+
+        <dependency>
             <groupId>com.aliyun.oss</groupId>
             <artifactId>aliyun-sdk-oss</artifactId>
             <version>3.17.4</version>

+ 137 - 0
src/main/java/com/zhentao/Ai/advertisement/DeepSeekWebSocketHandler.java

@@ -0,0 +1,137 @@
+package com.zhentao.Ai.advertisement;
+
+import com.google.gson.Gson;
+import com.google.gson.JsonObject;
+import com.google.gson.JsonParser;
+import com.zhentao.Ai.dto.DeeseekRequest;
+import org.springframework.web.socket.TextMessage;
+import org.springframework.web.socket.WebSocketSession;
+import org.springframework.web.socket.handler.TextWebSocketHandler;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.io.OutputStream;
+import java.net.HttpURLConnection;
+import java.net.SocketTimeoutException;
+import java.net.URL;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+
+public class DeepSeekWebSocketHandler extends TextWebSocketHandler {
+
+    private static final Gson gson = new Gson();
+    private static final String DEEPSEEK_API_URL = "https://api.deepseek.com/chat/completions";
+    private static final String API_KEY = "sk-df51dab7323441998d41f18494098ddc"; // 替换为你的API密钥
+
+        // ...其他成员变量...
+
+    @Override
+    protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
+        String question = message.getPayload();
+        if (question == null || question.trim().isEmpty()) {
+            sendErrorMessage(session, "问题不能为空");
+            return;
+        }
+
+        // 使用线程池管理请求(避免频繁创建线程)
+        ExecutorService executor = Executors.newSingleThreadExecutor();
+        executor.submit(() -> processStreamingResponse(session, question));
+        executor.shutdown();
+    }
+
+    private void processStreamingResponse(WebSocketSession session, String question) {
+        HttpURLConnection connection = null;
+        try {
+            URL url = new URL(DEEPSEEK_API_URL);
+            connection = (HttpURLConnection) url.openConnection();
+            connection.setRequestMethod("POST");
+            connection.setRequestProperty("Content-Type", "application/json");
+            connection.setRequestProperty("Authorization", "Bearer " + API_KEY);
+            connection.setRequestProperty("Accept", "text/event-stream");
+            connection.setDoOutput(true);
+            connection.setDoInput(true);
+            connection.setUseCaches(false);
+            connection.setConnectTimeout(5000); // 5秒连接超时
+            connection.setReadTimeout(30000);  // 30秒读取超时
+
+            // 发送请求体
+            try (OutputStream os = connection.getOutputStream()) {
+                os.write(buildRequestBody(question).getBytes(StandardCharsets.UTF_8));
+                os.flush();
+            }
+
+            // 处理流式响应
+            try (BufferedReader reader = new BufferedReader(
+                    new InputStreamReader(connection.getInputStream(), StandardCharsets.UTF_8))) {
+
+                String line;
+                while ((line = reader.readLine()) != null && session.isOpen()) {
+                    if (line.startsWith("data: ") && !line.equals("data: [DONE]")) {
+                        String content = parseContent(line.substring(6));
+                        if (content != null) {
+                            session.sendMessage(new TextMessage(content));
+                        }
+                    }
+                }
+            }
+        } catch (SocketTimeoutException e) {
+            sendErrorMessage(session, "API响应超时,请重试");
+        } catch (IOException e) {
+            sendErrorMessage(session, "网络错误: " + e.getMessage());
+        } finally {
+            if (connection != null) connection.disconnect();
+        }
+    }
+
+        private String buildRequestBody(String question) {
+            List<DeeseekRequest.Message> messages = new ArrayList<>();
+            messages.add(DeeseekRequest.Message.builder()
+                    .role("system")
+                    .content("你是一个佳佳聊天小助手,请用中文回答")
+                    .build());
+            messages.add(DeeseekRequest.Message.builder()
+                    .role("user")
+                    .content(question)
+                    .build());
+
+            return gson.toJson(DeeseekRequest.builder()
+                    .model("deepseek-chat")
+                    .messages(messages)
+                    .stream(true)
+                    .build());
+        }
+
+        private String parseContent(String json) {
+            try {
+                JsonObject obj = JsonParser.parseString(json).getAsJsonObject();
+                if (obj.has("choices")) {
+                    JsonObject delta = obj.getAsJsonArray("choices")
+                            .get(0).getAsJsonObject()
+                            .getAsJsonObject("delta");
+                    if (delta.has("content")) {
+                        System.err.println(delta.get("content").getAsString());
+                        return delta.get("content").getAsString();
+                    }
+                }
+                return null;
+            } catch (Exception e) {
+                return null;
+            }
+        }
+
+        private void sendErrorMessage(WebSocketSession session, String message) {
+            try {
+                if (session.isOpen()) {
+                    session.sendMessage(new TextMessage(
+                            "{\"error\": \"" + message + "\"}"
+                    ));
+                }
+            } catch (Exception e) {
+                e.printStackTrace();
+            }
+        }
+    }

+ 24 - 0
src/main/java/com/zhentao/Ai/config/WebSocketConfig.java

@@ -0,0 +1,24 @@
+package com.zhentao.Ai.config;
+
+import com.zhentao.Ai.advertisement.DeepSeekWebSocketHandler;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.web.socket.config.annotation.EnableWebSocket;
+import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
+import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
+import org.springframework.web.socket.server.support.DefaultHandshakeHandler;
+import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor;
+
+// WebSocketConfig.java
+@Configuration
+@EnableWebSocket
+public class WebSocketConfig implements WebSocketConfigurer {
+
+    @Override
+    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
+        registry.addHandler(new DeepSeekWebSocketHandler(), "/ws/deepseek")
+                .setAllowedOrigins("*")
+                .setHandshakeHandler(new DefaultHandshakeHandler())
+                .addInterceptors(new HttpSessionHandshakeInterceptor())
+                .setAllowedOriginPatterns("*"); // 更灵活的跨域控制
+    }
+}

+ 21 - 0
src/main/java/com/zhentao/Ai/dto/DeeseekRequest.java

@@ -0,0 +1,21 @@
+package com.zhentao.Ai.dto;
+
+import lombok.Builder;
+import lombok.Data;
+
+import java.util.List;
+
+@Data
+@Builder
+public class DeeseekRequest {
+    private String model;
+    private List<Message> messages;
+    private boolean stream;  // 关键:新增 stream 字段,控制是否流式输出
+
+    @Data
+    @Builder
+    public static class Message {
+        private String role;
+        private String content;
+    }
+}

+ 1 - 3
src/main/java/com/zhentao/userRelationships/service/impl/UserRelationshipsServiceImpl.java

@@ -9,12 +9,10 @@ import com.zhentao.user.domain.UserLogin;
 import com.zhentao.user.mapper.UserLoginMapper;
 import com.zhentao.userRelationships.domain.UserRelationships;
 import com.zhentao.userRelationships.domain.UserRequest;
-import com.zhentao.userRelationships.domain.UserShouye;
-import com.zhentao.userRelationships.dto.FriendDto;
+
 
 import com.zhentao.userRelationships.dto.FriendsTDto;
 import com.zhentao.userRelationships.mapper.UserRequestMapper;
-import com.zhentao.userRelationships.mapper.UserShouyeMapper;
 import com.zhentao.userRelationships.service.UserRelationshipsService;
 import com.zhentao.userRelationships.mapper.UserRelationshipsMapper;
 

+ 1 - 1
src/main/resources/application.yml

@@ -30,7 +30,7 @@ spring:
     mongodb:
       host: 101.200.59.170
       port: 27017
-      database: chat_messages
+      database: im_message_db
   redis:
     host: 101.200.59.170
     port: 6379