|
@@ -0,0 +1,137 @@
|
|
|
|
+package com.zhentao.Ai.advertisement;
|
|
|
|
+
|
|
|
|
+import com.google.gson.Gson;
|
|
|
|
+import com.google.gson.JsonObject;
|
|
|
|
+import com.google.gson.JsonParser;
|
|
|
|
+import com.zhentao.Ai.dto.DeeseekRequest;
|
|
|
|
+import org.springframework.web.socket.TextMessage;
|
|
|
|
+import org.springframework.web.socket.WebSocketSession;
|
|
|
|
+import org.springframework.web.socket.handler.TextWebSocketHandler;
|
|
|
|
+
|
|
|
|
+import java.io.BufferedReader;
|
|
|
|
+import java.io.IOException;
|
|
|
|
+import java.io.InputStreamReader;
|
|
|
|
+import java.io.OutputStream;
|
|
|
|
+import java.net.HttpURLConnection;
|
|
|
|
+import java.net.SocketTimeoutException;
|
|
|
|
+import java.net.URL;
|
|
|
|
+import java.nio.charset.StandardCharsets;
|
|
|
|
+import java.util.ArrayList;
|
|
|
|
+import java.util.List;
|
|
|
|
+import java.util.concurrent.ExecutorService;
|
|
|
|
+import java.util.concurrent.Executors;
|
|
|
|
+
|
|
|
|
+public class DeepSeekWebSocketHandler extends TextWebSocketHandler {
|
|
|
|
+
|
|
|
|
+ private static final Gson gson = new Gson();
|
|
|
|
+ private static final String DEEPSEEK_API_URL = "https://api.deepseek.com/chat/completions";
|
|
|
|
+ private static final String API_KEY = "sk-df51dab7323441998d41f18494098ddc"; // 替换为你的API密钥
|
|
|
|
+
|
|
|
|
+ // ...其他成员变量...
|
|
|
|
+
|
|
|
|
+ @Override
|
|
|
|
+ protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
|
|
|
|
+ String question = message.getPayload();
|
|
|
|
+ if (question == null || question.trim().isEmpty()) {
|
|
|
|
+ sendErrorMessage(session, "问题不能为空");
|
|
|
|
+ return;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // 使用线程池管理请求(避免频繁创建线程)
|
|
|
|
+ ExecutorService executor = Executors.newSingleThreadExecutor();
|
|
|
|
+ executor.submit(() -> processStreamingResponse(session, question));
|
|
|
|
+ executor.shutdown();
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ private void processStreamingResponse(WebSocketSession session, String question) {
|
|
|
|
+ HttpURLConnection connection = null;
|
|
|
|
+ try {
|
|
|
|
+ URL url = new URL(DEEPSEEK_API_URL);
|
|
|
|
+ connection = (HttpURLConnection) url.openConnection();
|
|
|
|
+ connection.setRequestMethod("POST");
|
|
|
|
+ connection.setRequestProperty("Content-Type", "application/json");
|
|
|
|
+ connection.setRequestProperty("Authorization", "Bearer " + API_KEY);
|
|
|
|
+ connection.setRequestProperty("Accept", "text/event-stream");
|
|
|
|
+ connection.setDoOutput(true);
|
|
|
|
+ connection.setDoInput(true);
|
|
|
|
+ connection.setUseCaches(false);
|
|
|
|
+ connection.setConnectTimeout(5000); // 5秒连接超时
|
|
|
|
+ connection.setReadTimeout(30000); // 30秒读取超时
|
|
|
|
+
|
|
|
|
+ // 发送请求体
|
|
|
|
+ try (OutputStream os = connection.getOutputStream()) {
|
|
|
|
+ os.write(buildRequestBody(question).getBytes(StandardCharsets.UTF_8));
|
|
|
|
+ os.flush();
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // 处理流式响应
|
|
|
|
+ try (BufferedReader reader = new BufferedReader(
|
|
|
|
+ new InputStreamReader(connection.getInputStream(), StandardCharsets.UTF_8))) {
|
|
|
|
+
|
|
|
|
+ String line;
|
|
|
|
+ while ((line = reader.readLine()) != null && session.isOpen()) {
|
|
|
|
+ if (line.startsWith("data: ") && !line.equals("data: [DONE]")) {
|
|
|
|
+ String content = parseContent(line.substring(6));
|
|
|
|
+ if (content != null) {
|
|
|
|
+ session.sendMessage(new TextMessage(content));
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ } catch (SocketTimeoutException e) {
|
|
|
|
+ sendErrorMessage(session, "API响应超时,请重试");
|
|
|
|
+ } catch (IOException e) {
|
|
|
|
+ sendErrorMessage(session, "网络错误: " + e.getMessage());
|
|
|
|
+ } finally {
|
|
|
|
+ if (connection != null) connection.disconnect();
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ private String buildRequestBody(String question) {
|
|
|
|
+ List<DeeseekRequest.Message> messages = new ArrayList<>();
|
|
|
|
+ messages.add(DeeseekRequest.Message.builder()
|
|
|
|
+ .role("system")
|
|
|
|
+ .content("你是一个佳佳聊天小助手,请用中文回答")
|
|
|
|
+ .build());
|
|
|
|
+ messages.add(DeeseekRequest.Message.builder()
|
|
|
|
+ .role("user")
|
|
|
|
+ .content(question)
|
|
|
|
+ .build());
|
|
|
|
+
|
|
|
|
+ return gson.toJson(DeeseekRequest.builder()
|
|
|
|
+ .model("deepseek-chat")
|
|
|
|
+ .messages(messages)
|
|
|
|
+ .stream(true)
|
|
|
|
+ .build());
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ private String parseContent(String json) {
|
|
|
|
+ try {
|
|
|
|
+ JsonObject obj = JsonParser.parseString(json).getAsJsonObject();
|
|
|
|
+ if (obj.has("choices")) {
|
|
|
|
+ JsonObject delta = obj.getAsJsonArray("choices")
|
|
|
|
+ .get(0).getAsJsonObject()
|
|
|
|
+ .getAsJsonObject("delta");
|
|
|
|
+ if (delta.has("content")) {
|
|
|
|
+ System.err.println(delta.get("content").getAsString());
|
|
|
|
+ return delta.get("content").getAsString();
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ return null;
|
|
|
|
+ } catch (Exception e) {
|
|
|
|
+ return null;
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ private void sendErrorMessage(WebSocketSession session, String message) {
|
|
|
|
+ try {
|
|
|
|
+ if (session.isOpen()) {
|
|
|
|
+ session.sendMessage(new TextMessage(
|
|
|
|
+ "{\"error\": \"" + message + "\"}"
|
|
|
|
+ ));
|
|
|
|
+ }
|
|
|
|
+ } catch (Exception e) {
|
|
|
|
+ e.printStackTrace();
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ }
|