diff --git a/src/main/java/com/example/springaidemo/controller/Chat2dbController.java b/src/main/java/com/example/springaidemo/controller/Chat2dbController.java new file mode 100644 index 0000000..205f3f4 --- /dev/null +++ b/src/main/java/com/example/springaidemo/controller/Chat2dbController.java @@ -0,0 +1,25 @@ +package com.example.springaidemo.controller; + +import com.example.springaidemo.bean.ChatRequest; +import com.example.springaidemo.service.Chat2dbService; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; + +import java.util.Map; + +@RestController +@RequestMapping("/Chat2dbController") +public class Chat2dbController { + + @Autowired + private Chat2dbService chat2dbService; + + @PostMapping("/chat2db") + public SseEmitter generate(@RequestBody ChatRequest chatRequest) { + return chat2dbService.chat2db(chatRequest.getMessage()); + } +} diff --git a/src/main/java/com/example/springaidemo/controller/ChatController.java b/src/main/java/com/example/springaidemo/controller/ChatController.java index 98f12af..9dc2928 100644 --- a/src/main/java/com/example/springaidemo/controller/ChatController.java +++ b/src/main/java/com/example/springaidemo/controller/ChatController.java @@ -21,14 +21,14 @@ public class ChatController { @Autowired private DeepseekChatService deepseekChatService; - @GetMapping("/ai/chat") - public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return deepseekChatService.chat(message); + @PostMapping("/ai/chat") + public Map generate(@RequestBody ChatRequest chatRequest) { + return deepseekChatService.chat(chatRequest.getMessage()); } @GetMapping("/ai/generateStream") - public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return deepseekChatService.streamChat(message); + public Flux generateStream(@RequestBody ChatRequest chatRequest) { + return deepseekChatService.streamChat(chatRequest.getMessage()); } @PostMapping(value = "/ai/sseChat", produces = MediaType.TEXT_EVENT_STREAM_VALUE) diff --git a/src/main/java/com/example/springaidemo/controller/RagController.java b/src/main/java/com/example/springaidemo/controller/RagController.java index 947fc77..88f92ab 100644 --- a/src/main/java/com/example/springaidemo/controller/RagController.java +++ b/src/main/java/com/example/springaidemo/controller/RagController.java @@ -1,10 +1,16 @@ package com.example.springaidemo.controller; +import com.example.springaidemo.bean.ChatRequest; import com.example.springaidemo.service.VectorService; +import org.springframework.ai.document.Document; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; +import java.util.List; +import java.util.Map; + @RestController @RequestMapping("/RagController") public class RagController { @@ -18,4 +24,16 @@ public class RagController { return "success"; } + @RequestMapping("/modelStore") + public String modelVector() { + vectorService.modelVector(); + return "success"; + } + + @RequestMapping("/searchModel") + public String searchModel(@RequestBody ChatRequest chatRequest) { + List documents = vectorService.searchModel(chatRequest.getMessage(), 3); + return "success"; + } + } diff --git a/src/main/java/com/example/springaidemo/service/Chat2dbService.java b/src/main/java/com/example/springaidemo/service/Chat2dbService.java new file mode 100644 index 0000000..a65413d --- /dev/null +++ b/src/main/java/com/example/springaidemo/service/Chat2dbService.java @@ -0,0 +1,59 @@ +package com.example.springaidemo.service; + +import cn.hutool.json.JSONObject; +import cn.hutool.json.JSONUtil; +import org.springframework.ai.deepseek.DeepSeekChatModel; +import org.springframework.ai.document.Document; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Service; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; + +import java.util.*; + +@Service +public class Chat2dbService { + + private final DeepSeekChatModel chatModel; + + @Autowired + private VectorService vectorService; + + @Autowired + private DeepseekChatService deepseekChatService; + + @Autowired + public Chat2dbService(DeepSeekChatModel chatModel) { + this.chatModel = chatModel; + } + + private Map modelMap = Map.of("tb_user", "{tableName:tb_user, tableDesc:用户表, " + + "cols[{fieldName:user_name,fieldDesc:用户名},{fieldName:user_sex,fieldDesc:用户性别},{fieldName:user_age,fieldDesc:用户年龄}" + + ",{fieldName:uuser_phone,fieldDesc:用户手机号},{fieldName:user_email,fieldDesc:用户邮箱},{fieldName:user_address,fieldDesc:用户地址}]}"); + + public SseEmitter chat2db(String question) { + String prompt = "请你从以下句子'" + question + "'中去分别去识别出可能是表名、字段名的信息. " + + "并按照以下json格式回复(不要用md格式包裹):{tables:['table1','table2'..],cols:['col1','col2'...]}"; + String generationJson = chatModel.call(prompt); + JSONObject jsonObject = JSONUtil.parseObj(generationJson); + List tableList = jsonObject.get("tables", List.class); + List colsList = jsonObject.get("cols", List.class); + Set tableSet = new HashSet<>(); + for (String table : tableList) { + for (String col : colsList) { + String vectorSearch = "表名 " + table + "|字段名 " + col; + List documents = vectorService.searchModel(vectorSearch,1); + if (documents.size() > 0) { + Map metadata = documents.get(0).getMetadata(); + String tableName = metadata.get("modelName").toString(); + tableSet.add(tableName); + } + } + } + StringBuilder sb = new StringBuilder(); + for (String tableName : tableSet) { + sb.append(modelMap.get(tableName)==null ? "" : modelMap.get(tableName).toString()); + } + String sqlPrompt = "请你根据以下的描述" + question + "去写一个sql语句, 你需要根据以下表与字段相关信息" + sb.toString() + "去生成, 如果没有提供表与字段相关信息, 则返回'没有找到相关的表和字段,请你提供更相信的信息'"; + return deepseekChatService.sseChat(sqlPrompt); + } +} diff --git a/src/main/java/com/example/springaidemo/service/VectorService.java b/src/main/java/com/example/springaidemo/service/VectorService.java index 1c449eb..088aba1 100644 --- a/src/main/java/com/example/springaidemo/service/VectorService.java +++ b/src/main/java/com/example/springaidemo/service/VectorService.java @@ -6,7 +6,9 @@ import org.springframework.ai.vectorstore.VectorStore; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; +import java.util.HashMap; import java.util.List; +import java.util.Map; @Service public class VectorService { @@ -29,4 +31,29 @@ public class VectorService { List results = this.vectorStore.similaritySearch(searchRequest); System.out.println(results); } + + public void modelVector() { + // 创建包含自定义字段的元数据 + Map metadata = new HashMap<>(); + metadata.put("modelName", "tb_vip_user"); + metadata.put("source", "dataCommon"); + metadata.put("timestamp", System.currentTimeMillis()); + List documents = List.of( + new Document("表 tb_vip_user VIP用户表 | 字段 user_name用户名", metadata), + new Document("表 tb_vip_user VIP用户表 | 字段 user_id用户id", metadata), + new Document("表 tb_vip_user VIP用户表 | 字段 user_sex用户性别", metadata), + new Document("表 tb_vip_user VIP用户表 | 字段 user_age用户年龄", metadata), + new Document("表 tb_vip_user VIP用户表 | 字段 user_phone用户手机号", metadata), + new Document("表 tb_vip_user VIP用户表 | 字段 user_email用户邮箱", metadata), + new Document("表 tb_vip_user VIP用户表 | 字段 user_address用户地址", metadata)); + vectorStore.add(documents); + } + + public List searchModel(String query, int topK) { + return vectorStore.similaritySearch(SearchRequest.builder() + .query(query) + .topK(topK) +// .filterExpression("source == 'dataCommon'") + .build()); + } }