Переглянути джерело

【优化】兼容阿里云千问开源模型和付费模型

cherishsince 1 рік тому
батько
коміт
63c5f90596

+ 44 - 23
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatqianwen/QianWenChatClient.java

@@ -4,13 +4,14 @@ import cn.iocoder.yudao.framework.ai.chat.*;
 import cn.iocoder.yudao.framework.ai.chat.prompt.ChatOptions;
 import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
 import cn.iocoder.yudao.framework.ai.chatqianwen.api.QianWenApi;
-import cn.iocoder.yudao.framework.ai.chatqianwen.api.QianWenChatCompletionRequest;
 import cn.iocoder.yudao.framework.ai.chatyiyan.exception.YiYanApiException;
 import com.alibaba.dashscope.aigc.generation.GenerationResult;
 import com.alibaba.dashscope.aigc.generation.models.QwenParam;
 import com.alibaba.dashscope.common.Message;
+import com.google.common.collect.Lists;
 import io.reactivex.Flowable;
 import lombok.extern.slf4j.Slf4j;
+import org.jetbrains.annotations.NotNull;
 import org.springframework.http.ResponseEntity;
 import org.springframework.retry.RetryCallback;
 import org.springframework.retry.RetryContext;
@@ -71,7 +72,7 @@ public class QianWenChatClient implements ChatClient, StreamingChatClient {
         return this.retryTemplate.execute(ctx -> {
             // ctx 会有重试的信息
             // 创建 request 请求,stream模式需要供应商支持
-            QianWenChatCompletionRequest request = this.createRequest(prompt, false);
+            QwenParam request = this.createRequest(prompt, false);
             // 调用 callWithFunctionSupport 发送请求
             ResponseEntity<GenerationResult> responseEntity = qianWenApi.chatCompletionEntity(request);
             // 获取结果封装 chatCompletion
@@ -81,11 +82,41 @@ public class QianWenChatClient implements ChatClient, StreamingChatClient {
 //                        response.getRequestId(), response.getCode(), response.getMessage()))));
 //            }
             // 转换为 Generation 返回
-            return new ChatResponse(List.of(new Generation(response.getOutput().getText())));
+            return new ChatResponse(response.getOutput().getChoices().stream()
+                    .map(choices -> new Generation(choices.getMessage().getContent()))
+                    .collect(Collectors.toList()));
         });
     }
 
-    private QianWenChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
+    private QwenParam createRequest(Prompt prompt, boolean stream) {
+        // Build the QwenParam for this prompt; start by resolving the effective QianWenOptions via getChatOptions
+        QianWenOptions chatOptions = getChatOptions(prompt);
+        // Copy every prompt instruction into a DashScope Message (role from its message type, content as-is)
+        List<Message> messageList = Lists.newArrayList();
+        prompt.getInstructions().stream().forEach(instruction -> {
+            Message message = new Message();
+            message.setRole(instruction.getMessageType().getValue());
+            message.setContent(instruction.getContent());
+            messageList.add(message);
+        });
+        return QwenParam.builder()
+                .model(qianWenApi.getQianWenChatModal().getValue())
+                .prompt(prompt.getContents())
+                .messages(messageList)
+                .maxTokens(chatOptions.getMaxTokens())
+                .resultFormat(QwenParam.ResultFormat.MESSAGE)
+                .topP(Double.valueOf(chatOptions.getTopP()))
+                .topK(chatOptions.getTopK())
+                .temperature(chatOptions.getTemperature())
+                // Streaming output mode: by default each later chunk repeats what was already output; true enables incremental output, where later chunks omit earlier content and the caller must concatenate the full text. NOTE(review): set unconditionally here - the `stream` argument is not read in this method; confirm this is intended for non-stream calls
+                .incrementalOutput(true)
+                /* random seed, fixed to 100 here; optional, defaults to 1234 if not set */
+                .seed(100)
+                .apiKey(qianWenApi.getApiKey())
+                .build();
+    }
+
+    private @NotNull QianWenOptions getChatOptions(Prompt prompt) {
         // 两个都为null 则没有配置文件
         if (qianWenOptions == null && prompt.getOptions() == null) {
             throw new ChatException("ChatOptions 未配置参数!");
@@ -96,37 +127,27 @@ public class QianWenChatClient implements ChatClient, StreamingChatClient {
             options = (ChatOptions) prompt.getOptions();
         }
         // Prompt 里面是一个 ChatOptions,用户可以随意传入,这里做一下判断
-        if (!(options instanceof QianWenOptions qianWenOptions)) {
+        if (!(options instanceof QianWenOptions)) {
             throw new ChatException("Prompt 传入的不是 QianWenOptions!");
         }
-        return (QianWenChatCompletionRequest) QianWenChatCompletionRequest.builder()
-                .model(qianWenApi.getQianWenChatModal().getValue())
-                .apiKey(qianWenApi.getApiKey())
-                .messages(prompt.getInstructions().stream().map(m -> {
-                    Message message = new Message();
-                    message.setRole(m.getMessageType().getValue());
-                    message.setContent(m.getContent());
-                    return message;
-                }).collect(Collectors.toList()))
-                .resultFormat(QwenParam.ResultFormat.MESSAGE)
-                // 动态改变的三个参数
-                .topP(Double.valueOf(qianWenOptions.getTopP()))
-                .topK(qianWenOptions.getTopK())
-                .temperature(qianWenOptions.getTemperature())
-                .incrementalOutput(true)
-                .build();
+        return (QianWenOptions) options;
     }
 
     @Override
     public Flux<ChatResponse> stream(Prompt prompt) {
         // ctx 会有重试的信息
         // 创建 request 请求,stream模式需要供应商支持
-        QianWenChatCompletionRequest request = this.createRequest(prompt, true);
+        QwenParam request = this.createRequest(prompt, true);
         // 调用 callWithFunctionSupport 发送请求
         Flowable<GenerationResult> responseResult = this.qianWenApi.chatCompletionStream(request);
+
         return Flux.create(fluxSink ->
                 responseResult.subscribe(
-                        value -> fluxSink.next(new ChatResponse(List.of(new Generation(value.getOutput().getText())))),
+                        value -> fluxSink.next(
+                                new ChatResponse(value.getOutput().getChoices().stream()
+                                        .map(choices -> new Generation(choices.getMessage().getContent()))
+                                        .collect(Collectors.toList()))
+                        ),
                         error -> fluxSink.error(error),
                         () -> fluxSink.complete()
                 )

+ 5 - 1
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatqianwen/QianWenOptions.java

@@ -15,7 +15,7 @@ import java.util.List;
  * time: 2024/3/15 19:57
  */
 @Data
-@Accessors
+@Accessors(chain = true)
 public class QianWenOptions implements ChatOptions {
 
     /**
@@ -28,6 +28,10 @@ public class QianWenOptions implements ChatOptions {
      * 默认值为0.8。注意,取值不要大于等于1
      */
     private Float topP;
+    /**
+     * Limits the number of tokens the model may generate. max_tokens is an upper bound; it does not mean that many tokens will actually be generated. For qwen1.5-14b-chat, qwen1.5-7b-chat, qwen-14b-chat and qwen-7b-chat the maximum and default are both 1500; for qwen-1.8b-chat, qwen-1.8b-longcontext-chat and qwen-72b-chat the maximum and default are both 2000
+     */
+    private Integer maxTokens = 1500;
 
     //
     // 适配 ChatOptions

+ 3 - 4
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatqianwen/api/QianWenApi.java

@@ -4,6 +4,7 @@ import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenChatModal;
 import cn.iocoder.yudao.framework.ai.exception.AiException;
 import com.alibaba.dashscope.aigc.generation.Generation;
 import com.alibaba.dashscope.aigc.generation.GenerationResult;
+import com.alibaba.dashscope.aigc.generation.models.QwenParam;
 import com.alibaba.dashscope.common.Message;
 import com.alibaba.dashscope.common.Role;
 import com.alibaba.dashscope.exception.InputRequiredException;
@@ -34,9 +35,7 @@ public class QianWenApi {
         this.qianWenChatModal = qianWenChatModal;
     }
 
-    public ResponseEntity<GenerationResult> chatCompletionEntity(QianWenChatCompletionRequest request) {
-        Message userMsg = Message.builder().role(Role.USER.getValue()).content("用萝卜、土豆、茄子做饭,给我个菜谱").build();
-
+    public ResponseEntity<GenerationResult> chatCompletionEntity(QwenParam request) {
         GenerationResult call;
         try {
             call = gen.call(request);
@@ -49,7 +48,7 @@ public class QianWenApi {
         return new ResponseEntity<>(call, HttpStatusCode.valueOf(200));
     }
 
-    public Flowable<GenerationResult> chatCompletionStream(QianWenChatCompletionRequest request) {
+    public Flowable<GenerationResult> chatCompletionStream(QwenParam request) {
         Flowable<GenerationResult> resultFlowable;
         try {
             resultFlowable = gen.streamCall(request);

+ 1 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java

@@ -47,6 +47,7 @@ public class YudaoAiAutoConfiguration {
         QianWenOptions qianWenOptions = new QianWenOptions();
         qianWenOptions.setTopK(qianWenProperties.getTopK());
         qianWenOptions.setTopP(qianWenProperties.getTopP());
+        qianWenOptions.setMaxTokens(qianWenProperties.getMaxTokens());
         qianWenOptions.setTemperature(qianWenProperties.getTemperature());
         return new QianWenChatClient(
                 new QianWenApi(

+ 4 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiProperties.java

@@ -47,6 +47,10 @@ public class YudaoAiProperties {
          * api key
          */
         private String apiKey;
+        /**
+         * Limits the number of tokens the model may generate; max_tokens is an upper bound and does not mean that many tokens will actually be generated
+         */
+        private Integer maxTokens;
         /**
          * model
          */

+ 57 - 11
yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/QianWenChatClientTests.java

@@ -1,13 +1,25 @@
 package cn.iocoder.yudao.framework.ai.chat;
 
+import cn.iocoder.yudao.framework.ai.chat.messages.SystemMessage;
+import cn.iocoder.yudao.framework.ai.chat.messages.UserMessage;
 import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
-import cn.iocoder.yudao.framework.ai.chatqianwen.api.QianWenApi;
 import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenChatClient;
+import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenChatModal;
 import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenOptions;
+import cn.iocoder.yudao.framework.ai.chatqianwen.api.QianWenApi;
+import com.alibaba.dashscope.aigc.generation.GenerationResult;
+import com.alibaba.dashscope.aigc.generation.models.QwenParam;
+import com.alibaba.dashscope.common.Message;
+import com.alibaba.dashscope.common.MessageManager;
+import com.alibaba.dashscope.common.Role;
+import com.alibaba.dashscope.exception.InputRequiredException;
+import com.alibaba.dashscope.exception.NoApiKeyException;
 import org.junit.Before;
 import org.junit.Test;
 import reactor.core.publisher.Flux;
 
+import java.util.ArrayList;
+import java.util.List;
 import java.util.Scanner;
 import java.util.function.Consumer;
 
@@ -21,28 +33,34 @@ public class QianWenChatClientTests {
 
     @Before
     public void setup() {
-        QianWenApi qianWenApi = new QianWenApi(
-                "LTAI5tNTVhXW4fLKUjMrr98z",
-                "ZJ0JQeyjzxxm5CfeTV6k1wNE9UsvZP",
-                "f0c1088824594f589c8f10567ccd929f_p_efm",
-                null
-        );
+        QianWenApi qianWenApi = new QianWenApi("sk-Zsd81gZYg7", QianWenChatModal.QWEN_72B_CHAT);
+        QianWenOptions qianWenOptions = new QianWenOptions();
+        qianWenOptions.setTopP(0.8F);
+        qianWenOptions.setTopK(3);
+        qianWenOptions.setTemperature(0.6F);
         qianWenChatClient = new QianWenChatClient(
                 qianWenApi,
-                new QianWenOptions()
-                        .setAppId("5f14955f201a44eb8dbe0c57250a32ce")
+                qianWenOptions
         );
     }
 
     @Test
     public void callTest() {
-        ChatResponse call = qianWenChatClient.call(new Prompt("Java语言怎么样?"));
+        List<cn.iocoder.yudao.framework.ai.chat.messages.Message> messages = new ArrayList<>();
+        messages.add(new SystemMessage("你是一个优质的小红书文艺作者,抒写着各城市的美好文化和风景。"));
+        messages.add(new UserMessage("长沙怎么样?"));
+
+        ChatResponse call = qianWenChatClient.call(new Prompt(messages));
         System.err.println(call.getResult());
     }
 
     @Test
     public void streamTest() {
-        Flux<ChatResponse> flux = qianWenChatClient.stream(new Prompt("Java语言怎么样?"));
+        List<cn.iocoder.yudao.framework.ai.chat.messages.Message> messages = new ArrayList<>();
+        messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
+        messages.add(new UserMessage("长沙怎么样?"));
+
+        Flux<ChatResponse> flux = qianWenChatClient.stream(new Prompt(messages));
         flux.subscribe(new Consumer<ChatResponse>() {
             @Override
             public void accept(ChatResponse chatResponse) {
@@ -54,4 +72,32 @@ public class QianWenChatClientTests {
         Scanner scanner = new Scanner(System.in);
         scanner.nextLine();
     }
+
+    @Test
+    public void qianwenDemoTest() throws NoApiKeyException, InputRequiredException {
+        com.alibaba.dashscope.aigc.generation.Generation gen = new com.alibaba.dashscope.aigc.generation.Generation();
+        MessageManager msgManager = new MessageManager(10);
+        Message systemMsg =
+                Message.builder().role(Role.SYSTEM.getValue()).content("You are a helpful assistant.").build();
+        Message userMsg = Message.builder().role(Role.USER.getValue()).content("就当前的海洋污染的情况,写一份限塑的倡议书提纲,需要有理有据地号召大家克制地使用塑料制品").build();
+        msgManager.add(systemMsg);
+        msgManager.add(userMsg);
+        QwenParam param =
+                QwenParam.builder().model("qwen-72b-chat").messages(msgManager.get())
+                        .resultFormat(QwenParam.ResultFormat.MESSAGE)
+                        .topP(0.8)
+                        /* set the random seed, optional, default to 1234 if not set */
+                        .seed(100)
+                        .apiKey("sk-Zsd81gZYg7")
+                        .build();
+        GenerationResult result = gen.call(param);
+        System.out.println(result);
+        System.out.println("-----------------");
+        System.out.println("-----------------");
+        msgManager.add(result);
+        param.setPrompt("能否缩短一些,只讲三点");
+        param.setMessages(msgManager.get());
+        result = gen.call(param);
+        System.out.println(result);
+    }
 }

+ 2 - 2
yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/midjourney/MidjourneyInteractionsTests.java

@@ -6,7 +6,7 @@ import cn.iocoder.yudao.framework.ai.midjourney.api.req.AttachmentsReq;
 import cn.iocoder.yudao.framework.ai.midjourney.api.req.DescribeReq;
 import cn.iocoder.yudao.framework.ai.midjourney.api.req.ReRollReq;
 import cn.iocoder.yudao.framework.ai.midjourney.api.res.UploadAttachmentsRes;
-import com.alibaba.fastjson.JSON;
+import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
 import org.junit.Before;
 import org.junit.Test;
 import org.springframework.core.io.FileSystemResource;
@@ -58,7 +58,7 @@ public class MidjourneyInteractionsTests {
                 new AttachmentsReq().setFileSystemResource(
                         new FileSystemResource(new File("/Users/fansili/Downloads/DSC01402.JPG")))
         );
-        System.err.println(JSON.toJSONString(res));
+        System.err.println(JsonUtils.toJsonString(res));
     }
 
     @Test

+ 5 - 8
yudao-server/src/main/resources/application-local.yaml

@@ -228,14 +228,11 @@ yudao:
     qianwen:
       enable: true
       aiPlatform: QIAN_WEN
-      temperature: 1
-      topP: 1
-      topK: 1
-      endpoint: bailian.cn-beijing.aliyuncs.com
-      accessKeyId: LTAI5tNTVhXW4fLKUjMrr98z
-      accessKeySecret: ZJ0JQeyjzxxm5CfeTV6k1wNE9UsvZP
-      agentKey: f0c1088824594f589c8f10567ccd929f_p_efm
-      appId: 5f14955f201a44eb8dbe0c57250a32ce
+      temperature: 0.85
+      topP: 0.8
+      topK: 0
+      api-key: sk-Zsd81gZYg7
+      max-tokens: 1500
     xinghuo:
       enable: true
       aiPlatform: XING_HUO