Browse Source

1、增加 chat service
2、可动态传入 modal,选择模型

cherishsince 1 year ago
parent
commit
ef701167b7

+ 9 - 0
yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/AiClientNameEnum.java

@@ -25,4 +25,13 @@ public enum AiClientNameEnum {
     private String name;
     private String name;
 
 
     private String message;
     private String message;
+
+    public static AiClientNameEnum valueOfName(String name) {
+        for (AiClientNameEnum nameEnum : AiClientNameEnum.values()) {
+            if (nameEnum.getName().equals(name)) {
+                return nameEnum;
+            }
+        }
+        throw new IllegalArgumentException("Invalid MessageType value: " + name);
+    }
 }
 }

+ 7 - 8
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/ChatController.java

@@ -2,20 +2,20 @@ package cn.iocoder.yudao.module.ai.controller;
 
 
 import cn.hutool.core.exceptions.ExceptionUtil;
 import cn.hutool.core.exceptions.ExceptionUtil;
 import cn.iocoder.yudao.framework.ai.chat.ChatResponse;
 import cn.iocoder.yudao.framework.ai.chat.ChatResponse;
-import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
 import cn.iocoder.yudao.framework.ai.config.AiClient;
 import cn.iocoder.yudao.framework.ai.config.AiClient;
 import cn.iocoder.yudao.framework.common.pojo.CommonResult;
 import cn.iocoder.yudao.framework.common.pojo.CommonResult;
-import cn.iocoder.yudao.module.ai.enums.AiClientNameEnum;
 import cn.iocoder.yudao.module.ai.service.ChatService;
 import cn.iocoder.yudao.module.ai.service.ChatService;
+import cn.iocoder.yudao.module.ai.vo.ChatReq;
 import io.swagger.v3.oas.annotations.Operation;
 import io.swagger.v3.oas.annotations.Operation;
 import io.swagger.v3.oas.annotations.tags.Tag;
 import io.swagger.v3.oas.annotations.tags.Tag;
 import lombok.AllArgsConstructor;
 import lombok.AllArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.http.MediaType;
 import org.springframework.http.MediaType;
+import org.springframework.validation.annotation.Validated;
 import org.springframework.web.bind.annotation.GetMapping;
 import org.springframework.web.bind.annotation.GetMapping;
+import org.springframework.web.bind.annotation.ModelAttribute;
 import org.springframework.web.bind.annotation.RequestMapping;
 import org.springframework.web.bind.annotation.RequestMapping;
-import org.springframework.web.bind.annotation.RequestParam;
 import org.springframework.web.bind.annotation.RestController;
 import org.springframework.web.bind.annotation.RestController;
 import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
 import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
 import reactor.core.publisher.Flux;
 import reactor.core.publisher.Flux;
@@ -43,17 +43,16 @@ public class ChatController {
 
 
     @Operation(summary = "聊天-chat", description = "这个一般等待时间比较久,需要全部完成才会返回!")
     @Operation(summary = "聊天-chat", description = "这个一般等待时间比较久,需要全部完成才会返回!")
     @GetMapping("/chat")
     @GetMapping("/chat")
-    public CommonResult<String> chat(@RequestParam("prompt") String prompt) {
-        ChatResponse callRes = aiClient.call(new Prompt(prompt), AiClientNameEnum.QIAN_WEN.getName());
-        return CommonResult.success(callRes.getResult().getOutput().getContent());
+    public CommonResult<String> chat(@Validated @ModelAttribute ChatReq req) {
+        return CommonResult.success(chatService.chat(req));
     }
     }
 
 
     // TODO @芋艿:调用这个方法异常,Unable to handle the Spring Security Exception because the response is already committed.
     // TODO @芋艿:调用这个方法异常,Unable to handle the Spring Security Exception because the response is already committed.
     @Operation(summary = "聊天-stream", description = "这里跟通义千问一样采用的是 Server-Sent Events (SSE) 通讯模式")
     @Operation(summary = "聊天-stream", description = "这里跟通义千问一样采用的是 Server-Sent Events (SSE) 通讯模式")
     @GetMapping(value = "/chatStream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
     @GetMapping(value = "/chatStream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
-    public SseEmitter chatStream(@RequestParam("prompt") String prompt) {
+    public SseEmitter chatStream(@Validated @ModelAttribute ChatReq req) {
         Utf8SseEmitter sseEmitter = new Utf8SseEmitter();
         Utf8SseEmitter sseEmitter = new Utf8SseEmitter();
-        Flux<ChatResponse> streamResponse = aiClient.stream(new Prompt(prompt), AiClientNameEnum.QIAN_WEN.getName());
+        Flux<ChatResponse> streamResponse = chatService.chatStream(req);
         streamResponse.subscribe(
         streamResponse.subscribe(
                 new Consumer<ChatResponse>() {
                 new Consumer<ChatResponse>() {
                     @Override
                     @Override

+ 1 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/AiChatRoleController.java → yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/ChatRoleController.java

@@ -17,7 +17,7 @@ import org.springframework.web.bind.annotation.*;
 @RestController
 @RestController
 @RequestMapping("/chat-role")
 @RequestMapping("/chat-role")
 @AllArgsConstructor
 @AllArgsConstructor
-public class AiChatRoleController {
+public class ChatRoleController {
 
 
     private final ChatRoleService chatRoleService;
     private final ChatRoleService chatRoleService;
 
 

+ 33 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/ChatService.java

@@ -0,0 +1,33 @@
+package cn.iocoder.yudao.module.ai.service;
+
+import cn.iocoder.yudao.framework.ai.chat.ChatResponse;
+import cn.iocoder.yudao.module.ai.enums.AiClientNameEnum;
+import cn.iocoder.yudao.module.ai.vo.ChatReq;
+import reactor.core.publisher.Flux;
+
+/**
+ * 聊天 chat
+ *
+ * @author fansili
+ * @time 2024/4/14 15:55
+ * @since 1.0
+ */
+public interface ChatService {
+
+    /**
+     * chat
+     *
+     * @param req
+     * @return
+     */
+    String chat(ChatReq req);
+
+    /**
+     * chat stream
+     *
+     * @param req
+     * @param clientNameEnum
+     * @return
+     */
+    Flux<ChatResponse> chatStream(ChatReq req);
+}

+ 62 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/ChatServiceImpl.java

@@ -0,0 +1,62 @@
+package cn.iocoder.yudao.module.ai.service.impl;
+
+import cn.iocoder.yudao.framework.ai.chat.ChatResponse;
+import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
+import cn.iocoder.yudao.framework.ai.config.AiClient;
+import cn.iocoder.yudao.module.ai.enums.AiClientNameEnum;
+import cn.iocoder.yudao.module.ai.service.ChatService;
+import cn.iocoder.yudao.module.ai.vo.ChatReq;
+import lombok.AllArgsConstructor;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.stereotype.Service;
+import reactor.core.publisher.Flux;
+
+/**
+ * 聊天 service
+ *
+ * @author fansili
+ * @time 2024/4/14 15:55
+ * @since 1.0
+ */
+@Slf4j
+@Service
+@AllArgsConstructor
+public class ChatServiceImpl implements ChatService {
+
+    private final AiClient aiClient;
+
+    /**
+     * chat
+     *
+     * @param req
+     * @return
+     */
+    public String chat(ChatReq req) {
+        AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal());
+        // 创建 chat 需要的 Prompt
+        Prompt prompt = new Prompt(req.getPrompt());
+        req.setTopK(req.getTopK());
+        req.setTopP(req.getTopP());
+        req.setTemperature(req.getTemperature());
+        // 发送 call 调用
+        ChatResponse call = aiClient.call(prompt, clientNameEnum.getName());
+        return call.getResult().getOutput().getContent();
+    }
+
+    /**
+     * chat stream
+     *
+     * @param req
+     * @return
+     */
+    @Override
+    public Flux<ChatResponse> chatStream(ChatReq req) {
+        AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal());
+        // 创建 chat 需要的 Prompt
+        Prompt prompt = new Prompt(req.getPrompt());
+        req.setTopK(req.getTopK());
+        req.setTopP(req.getTopP());
+        req.setTemperature(req.getTemperature());
+        return aiClient.stream(prompt, clientNameEnum.getName());
+    }
+}

+ 42 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/vo/ChatReq.java

@@ -0,0 +1,42 @@
+package cn.iocoder.yudao.module.ai.vo;
+
+import cn.iocoder.yudao.module.ai.enums.AiClientNameEnum;
+import io.swagger.v3.oas.annotations.media.Schema;
+import jakarta.validation.constraints.NotNull;
+import jakarta.validation.constraints.Size;
+import lombok.Data;
+import lombok.experimental.Accessors;
+
+/**
+ * chat req
+ *
+ * @author fansili
+ * @time 2024/4/14 16:12
+ * @since 1.0
+ */
+@Data
+@Accessors(chain = true)
+public class ChatReq {
+
+
+    @NotNull(message = "提示词不能为空!")
+    @Size(max = 3000, message = "提示词最大3000个字符!")
+    @Schema(description = "填入固定值,1 issues, 2 pr")
+    private String prompt;
+
+    @Schema(description = "用于控制随机性和多样性的温度参数")
+    private Float temperature;
+
+    @Schema(description = "生成时,核采样方法的概率阈值。例如,取值为0.8时,仅保留累计概率之和大于等于0.8的概率分布中的token,\n" +
+            "     * 作为随机采样的候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的随机性越低。\n" +
+            "     * 默认值为0.8。注意,取值不要大于等于1\n")
+    private Float topP;
+
+    @Schema(description = "在生成消息时采用的Top-K采样大小,表示模型生成回复时考虑的候选项集合的大小")
+    private Integer topK;
+
+    @Schema(description = "ai模型(查看 AiClientNameEnum)")
+    @NotNull(message = "模型不能为空!")
+    @Size(max = 30, message = "模型字符最大30个字符!")
+    private String modal;
+}