Przeglądaj źródła

【代码新增】AI:接入智谱 GLM-4 模型

YunaiV 8 miesięcy temu
rodzic
commit
6e64ae774e

+ 7 - 1
yudao-module-ai/yudao-spring-boot-starter-ai/pom.xml

@@ -20,14 +20,20 @@
     <dependencies>
         <dependency>
             <groupId>org.springframework.ai</groupId>
-            <artifactId>spring-ai-ollama-spring-boot-starter</artifactId>
+            <artifactId>spring-ai-zhipuai-spring-boot-starter</artifactId>
             <version>${spring-ai.version}</version>
         </dependency>
+
         <dependency>
             <groupId>org.springframework.ai</groupId>
             <artifactId>spring-ai-openai-spring-boot-starter</artifactId>
             <version>${spring-ai.version}</version>
         </dependency>
+        <dependency>
+            <groupId>org.springframework.ai</groupId>
+            <artifactId>spring-ai-ollama-spring-boot-starter</artifactId>
+            <version>${spring-ai.version}</version>
+        </dependency>
         <dependency>
             <groupId>org.springframework.ai</groupId>
             <artifactId>spring-ai-stability-ai-spring-boot-starter</artifactId>

+ 1 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/enums/AiPlatformEnum.java

@@ -17,6 +17,7 @@ public enum AiPlatformEnum {
     TONG_YI("TongYi", "通义千问"), // 阿里
     YI_YAN("YiYan", "文心一言"), // 百度
     DEEP_SEEK("DeepSeek", "DeepSeek"), // DeepSeek
+    ZHI_PU("ZhiPu", "智谱"), // 智谱 AI
     XING_HUO("XingHuo", "星火"), // 讯飞
 
     // ========== 国外平台 ==========

+ 21 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java

@@ -23,8 +23,12 @@ import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
 import org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration;
 import org.springframework.ai.autoconfigure.qianfan.QianFanChatProperties;
 import org.springframework.ai.autoconfigure.qianfan.QianFanConnectionProperties;
+import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration;
+import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiChatProperties;
+import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties;
 import org.springframework.ai.chat.model.ChatModel;
 import org.springframework.ai.image.ImageModel;
+import org.springframework.ai.model.function.FunctionCallbackContext;
 import org.springframework.ai.ollama.OllamaChatModel;
 import org.springframework.ai.ollama.api.OllamaApi;
 import org.springframework.ai.openai.OpenAiChatModel;
@@ -36,6 +40,8 @@ import org.springframework.ai.qianfan.QianFanChatModel;
 import org.springframework.ai.qianfan.api.QianFanApi;
 import org.springframework.ai.stabilityai.StabilityAiImageModel;
 import org.springframework.ai.stabilityai.api.StabilityAiApi;
+import org.springframework.ai.zhipuai.ZhiPuAiChatModel;
+import org.springframework.ai.zhipuai.api.ZhiPuAiApi;
 import org.springframework.retry.support.RetryTemplate;
 import org.springframework.web.client.ResponseErrorHandler;
 import org.springframework.web.client.RestClient;
@@ -61,6 +67,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
                     return buildYiYanChatModel(apiKey);
                 case DEEP_SEEK:
                     return buildDeepSeekChatModel(apiKey);
+                case ZHI_PU:
+                    return buildZhiPuChatModel(apiKey, url);
                 case XING_HUO:
                     return buildXingHuoChatModel(apiKey);
                 case OPENAI:
@@ -81,6 +89,10 @@ public class AiModelFactoryImpl implements AiModelFactory {
                 return SpringUtil.getBean(TongYiChatModel.class);
             case YI_YAN:
                 return SpringUtil.getBean(QianFanChatModel.class);
+            case DEEP_SEEK:
+                return SpringUtil.getBean(DeepSeekChatModel.class);
+            case ZHI_PU:
+                return SpringUtil.getBean(ZhiPuAiChatModel.class);
             case XING_HUO:
                 return SpringUtil.getBean(XingHuoChatModel.class);
             case OPENAI:
@@ -175,6 +187,15 @@ public class AiModelFactoryImpl implements AiModelFactory {
         return new DeepSeekChatModel(apiKey);
     }
 
+    /**
+     * 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiChatModel(ZhiPuAiConnectionProperties, ZhiPuAiChatProperties, RestClient.Builder, List, FunctionCallbackContext, RetryTemplate, ResponseErrorHandler)}
+     */
+    private ZhiPuAiChatModel buildZhiPuChatModel(String apiKey, String url) {
+        url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL);
+        ZhiPuAiApi zhiPuAiApi = new ZhiPuAiApi(url, apiKey);
+        return new ZhiPuAiChatModel(zhiPuAiApi);
+    }
+
     /**
      * 可参考 {@link YudaoAiAutoConfiguration#xingHuoChatClient(YudaoAiProperties)}
      */

+ 11 - 8
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/util/AiUtils.java

@@ -10,6 +10,7 @@ import org.springframework.ai.chat.prompt.ChatOptions;
 import org.springframework.ai.ollama.api.OllamaOptions;
 import org.springframework.ai.openai.OpenAiChatOptions;
 import org.springframework.ai.qianfan.QianFanChatOptions;
+import org.springframework.ai.zhipuai.ZhiPuAiChatOptions;
 
 /**
  * Spring AI 工具类
@@ -22,18 +23,20 @@ public class AiUtils {
         Float temperatureF = temperature != null ? temperature.floatValue() : null;
         //noinspection EnhancedSwitchMigration
         switch (platform) {
-            case OPENAI:
-                return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
-            case OLLAMA:
-                return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens);
-            case YI_YAN:
-                return QianFanChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
-            case XING_HUO:
-                return XingHuoChatOptions.builder().model(model).temperature(temperatureF).maxTokens(maxTokens).build();
             case TONG_YI:
                 return TongYiChatOptions.builder().withModel(model).withTemperature(temperature).withMaxTokens(maxTokens).build();
+            case YI_YAN:
+                return QianFanChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
             case DEEP_SEEK:
                 return DeepSeekChatOptions.builder().model(model).temperature(temperatureF).maxTokens(maxTokens).build();
+            case ZHI_PU:
+                return ZhiPuAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
+            case XING_HUO:
+                return XingHuoChatOptions.builder().model(model).temperature(temperatureF).maxTokens(maxTokens).build();
+            case OPENAI:
+                return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
+            case OLLAMA:
+                return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens);
             default:
                 throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
         }

+ 3 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/DeepSeekChatModelTests.java

@@ -1,6 +1,7 @@
 package cn.iocoder.yudao.framework.ai.chat;
 
 import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel;
+import org.junit.jupiter.api.Disabled;
 import org.junit.jupiter.api.Test;
 import org.springframework.ai.chat.messages.Message;
 import org.springframework.ai.chat.messages.SystemMessage;
@@ -22,6 +23,7 @@ public class DeepSeekChatModelTests {
     private final DeepSeekChatModel chatModel = new DeepSeekChatModel("sk-e94db327cc7d457d99a8de8810fc6b12");
 
     @Test
+    @Disabled
     public void testCall() {
         // 准备参数
         List<Message> messages = new ArrayList<>();
@@ -35,6 +37,7 @@ public class DeepSeekChatModelTests {
     }
 
     @Test
+    @Disabled
     public void testStream() {
         // 准备参数
         List<Message> messages = new ArrayList<>();

+ 3 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/XingHuoChatModelTests.java

@@ -1,6 +1,7 @@
 package cn.iocoder.yudao.framework.ai.chat;
 
 import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
+import org.junit.jupiter.api.Disabled;
 import org.junit.jupiter.api.Test;
 import org.springframework.ai.chat.messages.Message;
 import org.springframework.ai.chat.messages.SystemMessage;
@@ -24,6 +25,7 @@ public class XingHuoChatModelTests {
             "Y2JiYTIxZjA3MDMxMjNjZjQzYzVmNzdh");
 
     @Test
+    @Disabled
     public void testCall() {
         // 准备参数
         List<Message> messages = new ArrayList<>();
@@ -37,6 +39,7 @@ public class XingHuoChatModelTests {
     }
 
     @Test
+    @Disabled
     public void testStream() {
         // 准备参数
         List<Message> messages = new ArrayList<>();

+ 61 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/ZhiPuAiChatModelTests.java

@@ -0,0 +1,61 @@
+package cn.iocoder.yudao.framework.ai.chat;
+
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.Test;
+import org.springframework.ai.chat.messages.Message;
+import org.springframework.ai.chat.messages.SystemMessage;
+import org.springframework.ai.chat.messages.UserMessage;
+import org.springframework.ai.chat.model.ChatResponse;
+import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.ai.zhipuai.ZhiPuAiChatModel;
+import org.springframework.ai.zhipuai.ZhiPuAiChatOptions;
+import org.springframework.ai.zhipuai.api.ZhiPuAiApi;
+import reactor.core.publisher.Flux;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * {@link ZhiPuAiChatModel} 的集成测试
+ *
+ * @author 芋道源码
+ */
+public class ZhiPuAiChatModelTests {
+
+    private final ZhiPuAiApi zhiPuAiApi = new ZhiPuAiApi("32f84543e54eee31f8d56b2bd6020573.3vh9idLJZ2ZhxDEs");
+    private final ZhiPuAiChatModel chatModel = new ZhiPuAiChatModel(zhiPuAiApi,
+            ZhiPuAiChatOptions.builder().withModel(ZhiPuAiApi.ChatModel.GLM_4.getModelName()).build());
+
+    @Test
+    @Disabled
+    public void testCall() {
+        // 准备参数
+        List<Message> messages = new ArrayList<>();
+        messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
+        messages.add(new UserMessage("1 + 1 = ?"));
+
+        // 调用
+        ChatResponse response = chatModel.call(new Prompt(messages));
+        // 打印结果
+        System.out.println(response);
+        System.out.println(response.getResult().getOutput());
+    }
+
+    @Test
+    @Disabled
+    public void testStream() {
+        // 准备参数
+        List<Message> messages = new ArrayList<>();
+        messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
+        messages.add(new UserMessage("1 + 1 = ?"));
+
+        // 调用
+        Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
+        // 打印结果
+        flux.doOnNext(response -> {
+//            System.out.println(response);
+            System.out.println(response.getResult().getOutput());
+        }).then().block();
+    }
+
+}

+ 7 - 5
yudao-server/src/main/resources/application.yaml

@@ -156,13 +156,15 @@ spring:
     qianfan: # 文心一言
       api-key: x0cuLZ7XsaTCU08vuJWO87Lg
       secret-key: R9mYF9dl9KASgi5RUq0FQt3wRisSnOcK
+    zhipuai: # 智谱 AI
+      api-key: 32f84543e54eee31f8d56b2bd6020573.3vh9idLJZ2ZhxDEs
+    openai:
+      api-key: sk-yzKea6d8e8212c3bdd99f9f44ced1cae37c097e5aa3BTS7z
+      base-url: https://api.gptsapi.net
     ollama:
       base-url: http://127.0.0.1:11434
       chat:
         model: llama3
-    openai:
-      api-key: sk-yzKea6d8e8212c3bdd99f9f44ced1cae37c097e5aa3BTS7z
-      base-url: https://api.gptsapi.net
     stabilityai:
       api-key: sk-e53UqbboF8QJCscYvzJscJxJXoFcFg4iJjl1oqgE7baJETmx
   cloud:
@@ -173,11 +175,11 @@ spring:
 
 yudao:
   ai:
-    deep-seek:
+    deep-seek: # DeepSeek
       enable: true
       api-key: sk-e94db327cc7d457d99a8de8810fc6b12
       model: deepseek-chat
-    xinghuo:
+    xinghuo: # 讯飞星火
       enable: true
       appId: 13c8cca6
       appKey: cb6415c19d6162cda07b47316fcb0416