20250728 sql助手简单实现(通过大模型识别出表、字段信息,然后通过向量召回表信息,再根据表信息生成sql)

This commit is contained in:
liangjinglin 2025-07-28 22:23:40 +08:00
parent c1f0f4766d
commit 65b2f93fc2
5 changed files with 134 additions and 5 deletions

View File

@ -0,0 +1,25 @@
package com.example.springaidemo.controller;
import com.example.springaidemo.bean.ChatRequest;
import com.example.springaidemo.service.Chat2dbService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.util.Map;
@RestController
@RequestMapping("/Chat2dbController")
public class Chat2dbController {
@Autowired
private Chat2dbService chat2dbService;
@PostMapping("/chat2db")
public SseEmitter generate(@RequestBody ChatRequest chatRequest) {
return chat2dbService.chat2db(chatRequest.getMessage());
}
}

View File

@ -21,14 +21,14 @@ public class ChatController {
@Autowired
private DeepseekChatService deepseekChatService;
@GetMapping("/ai/chat")
public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) {
return deepseekChatService.chat(message);
@PostMapping("/ai/chat")
public Map generate(@RequestBody ChatRequest chatRequest) {
return deepseekChatService.chat(chatRequest.getMessage());
}
@GetMapping("/ai/generateStream")
public Flux<ChatResponse> generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) {
return deepseekChatService.streamChat(message);
public Flux<ChatResponse> generateStream(@RequestBody ChatRequest chatRequest) {
return deepseekChatService.streamChat(chatRequest.getMessage());
}
@PostMapping(value = "/ai/sseChat", produces = MediaType.TEXT_EVENT_STREAM_VALUE)

View File

@ -1,10 +1,16 @@
package com.example.springaidemo.controller;
import com.example.springaidemo.bean.ChatRequest;
import com.example.springaidemo.service.VectorService;
import org.springframework.ai.document.Document;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import java.util.List;
import java.util.Map;
@RestController
@RequestMapping("/RagController")
public class RagController {
@ -18,4 +24,16 @@ public class RagController {
return "success";
}
@RequestMapping("/modelStore")
public String modelVector() {
vectorService.modelVector();
return "success";
}
@RequestMapping("/searchModel")
public String searchModel(@RequestBody ChatRequest chatRequest) {
List<Document> documents = vectorService.searchModel(chatRequest.getMessage(), 3);
return "success";
}
}

View File

@ -0,0 +1,59 @@
package com.example.springaidemo.service;
import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil;
import org.springframework.ai.deepseek.DeepSeekChatModel;
import org.springframework.ai.document.Document;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.util.*;
@Service
public class Chat2dbService {
private final DeepSeekChatModel chatModel;
@Autowired
private VectorService vectorService;
@Autowired
private DeepseekChatService deepseekChatService;
@Autowired
public Chat2dbService(DeepSeekChatModel chatModel) {
this.chatModel = chatModel;
}
private Map<String, Object> modelMap = Map.of("tb_user", "{tableName:tb_user, tableDesc:用户表, " +
"cols[{fieldName:user_name,fieldDesc:用户名},{fieldName:user_sex,fieldDesc:用户性别},{fieldName:user_age,fieldDesc:用户年龄}" +
",{fieldName:uuser_phone,fieldDesc:用户手机号},{fieldName:user_email,fieldDesc:用户邮箱},{fieldName:user_address,fieldDesc:用户地址}]}");
public SseEmitter chat2db(String question) {
String prompt = "请你从以下句子'" + question + "'中去分别去识别出可能是表名、字段名的信息. " +
"并按照以下json格式回复(不要用md格式包裹):{tables:['table1','table2'..],cols:['col1','col2'...]}";
String generationJson = chatModel.call(prompt);
JSONObject jsonObject = JSONUtil.parseObj(generationJson);
List<String> tableList = jsonObject.get("tables", List.class);
List<String> colsList = jsonObject.get("cols", List.class);
Set<String> tableSet = new HashSet<>();
for (String table : tableList) {
for (String col : colsList) {
String vectorSearch = "表名 " + table + "|字段名 " + col;
List<Document> documents = vectorService.searchModel(vectorSearch,1);
if (documents.size() > 0) {
Map<String, Object> metadata = documents.get(0).getMetadata();
String tableName = metadata.get("modelName").toString();
tableSet.add(tableName);
}
}
}
StringBuilder sb = new StringBuilder();
for (String tableName : tableSet) {
sb.append(modelMap.get(tableName)==null ? "" : modelMap.get(tableName).toString());
}
String sqlPrompt = "请你根据以下的描述" + question + "去写一个sql语句, 你需要根据以下表与字段相关信息" + sb.toString() + "去生成, 如果没有提供表与字段相关信息, 则返回'没有找到相关的表和字段,请你提供更相信的信息'";
return deepseekChatService.sseChat(sqlPrompt);
}
}

View File

@ -6,7 +6,9 @@ import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@Service
public class VectorService {
@ -29,4 +31,29 @@ public class VectorService {
List<Document> results = this.vectorStore.similaritySearch(searchRequest);
System.out.println(results);
}
public void modelVector() {
// 创建包含自定义字段的元数据
Map<String, Object> metadata = new HashMap<>();
metadata.put("modelName", "tb_vip_user");
metadata.put("source", "dataCommon");
metadata.put("timestamp", System.currentTimeMillis());
List<Document> documents = List.of(
new Document("表 tb_vip_user VIP用户表 | 字段 user_name用户名", metadata),
new Document("表 tb_vip_user VIP用户表 | 字段 user_id用户id", metadata),
new Document("表 tb_vip_user VIP用户表 | 字段 user_sex用户性别", metadata),
new Document("表 tb_vip_user VIP用户表 | 字段 user_age用户年龄", metadata),
new Document("表 tb_vip_user VIP用户表 | 字段 user_phone用户手机号", metadata),
new Document("表 tb_vip_user VIP用户表 | 字段 user_email用户邮箱", metadata),
new Document("表 tb_vip_user VIP用户表 | 字段 user_address用户地址", metadata));
vectorStore.add(documents);
}
public List<Document> searchModel(String query, int topK) {
return vectorStore.similaritySearch(SearchRequest.builder()
.query(query)
.topK(topK)
// .filterExpression("source == 'dataCommon'")
.build());
}
}