瀏覽代碼

【优化】处理百度 system 角色定制失效问题。

cherishsince 1 年之前
父節點
當前提交
10a94c3ef2

+ 46 - 16
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatyiyan/YiYanChatClient.java

@@ -2,6 +2,8 @@ package cn.iocoder.yudao.framework.ai.chatyiyan;
 
 import cn.hutool.core.bean.BeanUtil;
 import cn.iocoder.yudao.framework.ai.chat.*;
+import cn.iocoder.yudao.framework.ai.chat.messages.Message;
+import cn.iocoder.yudao.framework.ai.chat.messages.MessageType;
 import cn.iocoder.yudao.framework.ai.chat.prompt.ChatOptions;
 import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
 import cn.iocoder.yudao.framework.ai.chatyiyan.api.YiYanApi;
@@ -9,6 +11,7 @@ import cn.iocoder.yudao.framework.ai.chatyiyan.api.YiYanChatCompletion;
 import cn.iocoder.yudao.framework.ai.chatyiyan.api.YiYanChatCompletionRequest;
 import cn.iocoder.yudao.framework.ai.chatyiyan.exception.YiYanApiException;
 import lombok.extern.slf4j.Slf4j;
+import org.jetbrains.annotations.NotNull;
 import org.springframework.http.ResponseEntity;
 import org.springframework.retry.RetryCallback;
 import org.springframework.retry.RetryContext;
@@ -18,10 +21,11 @@ import reactor.core.publisher.Flux;
 
 import java.time.Duration;
 import java.util.List;
+import java.util.stream.Collectors;
 
 /**
  * 文心一言
- *
+ * <p>
  * author: fansili
  * time: 2024/3/8 19:11
  */
@@ -52,7 +56,9 @@ public class YiYanChatClient implements ChatClient, StreamingChatClient {
                 public <T extends Object, E extends Throwable> void onError(RetryContext context,
                                                                             RetryCallback<T, E> callback, Throwable throwable) {
                     log.warn("重试异常:" + context.getRetryCount(), throwable);
-                };
+                }
+
+                ;
             })
             .build();
 
@@ -92,6 +98,42 @@ public class YiYanChatClient implements ChatClient, StreamingChatClient {
     }
 
     private YiYanChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
+        // 获取配置
+        YiYanOptions useOptions = getYiYanOptions(prompt);
+        // 创建 request
+
+        // tip: 百度的 system 不在 message 里面
+        // tip:百度的 message 只有 user 和 assistant
+        // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
+
+        // 获取 user 和 assistant
+        List<YiYanChatCompletionRequest.Message> messageList = prompt.getInstructions().stream()
+                // 过滤 system
+                .filter(msg -> MessageType.SYSTEM != msg.getMessageType())
+                .map(msg -> new YiYanChatCompletionRequest.Message()
+                        .setRole(msg.getMessageType().getValue())
+                        .setContent(msg.getContent())
+                ).toList();
+        // 获取 system
+        String systemPrompt = prompt.getInstructions().stream()
+                .filter(msg -> MessageType.SYSTEM == msg.getMessageType())
+                .map(Message::getContent)
+                .collect(Collectors.joining());
+
+        YiYanChatCompletionRequest request = new YiYanChatCompletionRequest(messageList);
+        // 复制 qianWenOptions 属性取 request(这里 options 属性和 request 基本保持一致)
+        // top: 由于遵循 spring-ai规范,支持在构建client的时候传入默认的 chatOptions
+        BeanUtil.copyProperties(useOptions, request);
+        request.setTop_p(useOptions.getTopP());
+        request.setMax_output_tokens(useOptions.getMaxOutputTokens());
+        request.setTemperature(useOptions.getTemperature());
+        request.setSystem(systemPrompt);
+        // 设置 stream
+        request.setStream(stream);
+        return request;
+    }
+
+    private @NotNull YiYanOptions getYiYanOptions(Prompt prompt) {
         // 两个都为null 则没有配置文件
         if (yiYanOptions == null && prompt.getOptions() == null) {
             throw new ChatException("ChatOptions 未配置参数!");
@@ -106,19 +148,7 @@ public class YiYanChatClient implements ChatClient, StreamingChatClient {
             throw new ChatException("Prompt 传入的不是 YiYanOptions!");
         }
         // 转换 YiYanOptions
-        YiYanOptions qianWenOptions = (YiYanOptions) options;
-        // 创建 request
-        List<YiYanChatCompletionRequest.Message> messageList = prompt.getInstructions().stream().map(
-                msg -> new YiYanChatCompletionRequest.Message()
-                        .setRole(msg.getMessageType().getValue())
-                        .setContent(msg.getContent())
-        ).toList();
-        YiYanChatCompletionRequest request = new YiYanChatCompletionRequest(messageList);
-        // 复制 qianWenOptions 属性取 request(这里 options 属性和 request 基本保持一致)
-        // top: 由于遵循 spring-ai规范,支持在构建client的时候传入默认的 chatOptions
-        BeanUtil.copyProperties(qianWenOptions, request);
-        // 设置 stream
-        request.setStream(stream);
-        return request;
+        YiYanOptions useOptions = (YiYanOptions) options;
+        return useOptions;
     }
 }

+ 4 - 6
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatyiyan/YiYanOptions.java

@@ -2,7 +2,6 @@ package cn.iocoder.yudao.framework.ai.chatyiyan;
 
 import cn.iocoder.yudao.framework.ai.chat.prompt.ChatOptions;
 import cn.iocoder.yudao.framework.ai.chatyiyan.api.YiYanChatCompletionRequest;
-import com.fasterxml.jackson.annotation.JsonProperty;
 import lombok.Data;
 import lombok.experimental.Accessors;
 
@@ -40,7 +39,7 @@ public class YiYanOptions implements ChatOptions {
      * (2)默认0.8,取值范围 [0, 1.0]
      * 必填:否
      */
-    private Float top_p;
+    private Float topP;
     /**
      * 通过对已生成的token增加惩罚,减少重复生成的现象。说明:
      * (1)值越大表示惩罚越大
@@ -84,7 +83,7 @@ public class YiYanOptions implements ChatOptions {
      * 指定模型最大输出token数,范围[2, 2048]
      * 必填:否
      */
-    private Integer max_output_tokens;
+    private Integer maxOutputTokens;
     /**
      * 指定响应内容的格式,说明:
      * (1)可选值:
@@ -122,12 +121,12 @@ public class YiYanOptions implements ChatOptions {
 
     @Override
     public Float getTopP() {
-        return top_p;
+        return topP;
     }
 
     @Override
     public void setTopP(Float topP) {
-        this.top_p = topP;
+        this.topP = topP;
     }
 
     // 百度么有 topK
@@ -139,6 +138,5 @@ public class YiYanOptions implements ChatOptions {
 
     @Override
     public void setTopK(Integer topK) {
-
     }
 }

+ 27 - 4
yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/YiYanChatTests.java

@@ -1,5 +1,8 @@
 package cn.iocoder.yudao.framework.ai.chat;
 
+import cn.iocoder.yudao.framework.ai.chat.messages.Message;
+import cn.iocoder.yudao.framework.ai.chat.messages.SystemMessage;
+import cn.iocoder.yudao.framework.ai.chat.messages.UserMessage;
 import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
 import cn.iocoder.yudao.framework.ai.chatyiyan.YiYanChatClient;
 import cn.iocoder.yudao.framework.ai.chatyiyan.YiYanChatModel;
@@ -9,11 +12,13 @@ import org.junit.Before;
 import org.junit.Test;
 import reactor.core.publisher.Flux;
 
+import java.util.ArrayList;
+import java.util.List;
 import java.util.Scanner;
 
 /**
  * chat 文心一言
- *
+ * <p>
  * author: fansili
  * time: 2024/3/12 20:59
  */
@@ -29,18 +34,36 @@ public class YiYanChatTests {
                 YiYanChatModel.ERNIE4_3_5_8K,
                 86400
         );
-        yiYanChatClient = new YiYanChatClient(yiYanApi, new YiYanOptions().setMax_output_tokens(2048));
+        YiYanOptions yiYanOptions = new YiYanOptions();
+        yiYanOptions.setMaxOutputTokens(2048);
+        yiYanOptions.setTopP(0.6f);
+        yiYanOptions.setTemperature(0.85f);
+        yiYanChatClient = new YiYanChatClient(
+                yiYanApi,
+                yiYanOptions
+        );
     }
 
     @Test
     public void callTest() {
-        ChatResponse call = yiYanChatClient.call(new Prompt("什么编程语言最好?"));
+
+        // tip: 百度的message 有特殊规则(最后一个message为当前请求的信息,前面的message为历史对话信息)
+        // tip: 地址 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/jlil56u11
+        List<Message> messages = new ArrayList<>();
+        messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景,所有问题都采用文言文回答。"));
+        messages.add(new UserMessage("长沙怎么样?"));
+
+        ChatResponse call = yiYanChatClient.call(new Prompt(messages));
         System.err.println(call.getResult());
     }
 
     @Test
     public void streamTest() {
-        Flux<ChatResponse> fluxResponse = yiYanChatClient.stream(new Prompt("用java帮我写一个快排算法?"));
+        List<Message> messages = new ArrayList<>();
+        messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景,所有问题都采用文言文回答。"));
+        messages.add(new UserMessage("长沙怎么样?"));
+
+        Flux<ChatResponse> fluxResponse = yiYanChatClient.stream(new Prompt(messages));
         fluxResponse.subscribe(chatResponse -> System.err.print(chatResponse.getResult().getOutput().getContent()));
         // 阻止退出
         Scanner scanner = new Scanner(System.in);