|
@@ -1,6 +1,7 @@
|
|
|
package cn.iocoder.yudao.module.ai.service.write;
|
|
|
|
|
|
import cn.hutool.core.collection.CollUtil;
|
|
|
+import cn.hutool.core.lang.Assert;
|
|
|
import cn.hutool.core.util.StrUtil;
|
|
|
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
|
|
import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
|
|
@@ -67,19 +68,15 @@ public class AiWriteServiceImpl implements AiWriteService {
|
|
|
|
|
|
@Override
|
|
|
public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
|
|
|
- // 1.1 获取写作模型 尝试获取写作助手角色,如果没有则使用默认模型
|
|
|
- AiChatRoleDO writeRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName()));
|
|
|
- // TODO @xin:如果有 writeRole,但是没 modeId,是不是也可以用 systemMessage 哈?建议的写法是:先通过 modelId 获取 model。如果 model == null,则 chatModalService.getRequiredDefaultChatModel();如果还是 null,则抛出异常;。。。。。。。。。。。。。。然后,systemMessage = writeRole != null && writeRole.systemPrompt != "" 这样处理。
|
|
|
- AiChatModelDO model;
|
|
|
- String systemMessage;
|
|
|
- if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) {
|
|
|
- model = chatModalService.getChatModel(writeRole.getModelId());
|
|
|
- systemMessage = writeRole.getSystemMessage();
|
|
|
- } else {
|
|
|
- model = chatModalService.getRequiredDefaultChatModel();
|
|
|
- systemMessage = AiChatRoleEnum.AI_WRITE_ROLE.getSystemMessage();
|
|
|
- }
|
|
|
- // 1.2 校验平台
|
|
|
+ // 1 获取写作模型。尝试获取写作助手角色,没有则使用默认模型
|
|
|
+ AiChatRoleDO writeRole = CollUtil.getFirst(
|
|
|
+ chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName()));
|
|
|
+ // 1.1 获取写作执行模型
|
|
|
+ AiChatModelDO model = getModel(writeRole);
|
|
|
+ // 1.2 获取角色设定消息
|
|
|
+ String systemMessage = Objects.nonNull(writeRole) && StrUtil.isNotBlank(writeRole.getSystemMessage())
|
|
|
+ ? writeRole.getSystemMessage() : AiChatRoleEnum.AI_WRITE_ROLE.getSystemMessage();
|
|
|
+ // 1.3 校验平台
|
|
|
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
|
|
|
StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
|
|
|
|
|
@@ -88,21 +85,11 @@ public class AiWriteServiceImpl implements AiWriteService {
|
|
|
write -> write.setUserId(userId).setPlatform(platform.getPlatform()).setModel(model.getModel()));
|
|
|
writeMapper.insert(writeDO);
|
|
|
|
|
|
- // 3. 调用大模型,写作生成
|
|
|
- ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
|
|
|
- // 3.1 角色设定
|
|
|
- // TODO @xin:要不把 90 到 97 这部分,合并到一个方法里。目的是:让这个方法的主干更明确
|
|
|
- List<Message> chatMessages = new ArrayList<>();
|
|
|
- if (StrUtil.isNotBlank(systemMessage)) {
|
|
|
- chatMessages.add(new SystemMessage(systemMessage));
|
|
|
- }
|
|
|
- // 3.2 用户输入
|
|
|
- chatMessages.add(new UserMessage(buildWritingPrompt(generateReqVO)));
|
|
|
- // 3.3 构建提示词
|
|
|
- Prompt prompt = new Prompt(chatMessages, chatOptions);
|
|
|
+ // 3.1 构建 Prompt,并进行调用
|
|
|
+ Prompt prompt = buildPrompt(generateReqVO, model, systemMessage);
|
|
|
Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
|
|
|
|
|
|
- // 4. 流式返回
|
|
|
+ // 3.2 流式返回
|
|
|
StringBuffer contentBuffer = new StringBuffer();
|
|
|
return streamResponse.map(chunk -> {
|
|
|
String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null;
|
|
@@ -122,7 +109,39 @@ public class AiWriteServiceImpl implements AiWriteService {
|
|
|
}).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR)));
|
|
|
}
|
|
|
|
|
|
- private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) {
|
|
|
+ private AiChatModelDO getModel(AiChatRoleDO writeRole) {
|
|
|
+ AiChatModelDO model = null;
|
|
|
+ if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) {
|
|
|
+ model = chatModalService.getChatModel(writeRole.getModelId());
|
|
|
+ }
|
|
|
+ if (Objects.isNull(model)) {
|
|
|
+ model = chatModalService.getRequiredDefaultChatModel();
|
|
|
+ }
|
|
|
+ Assert.notNull(model, "[AI] 获取不到模型");
|
|
|
+ return model;
|
|
|
+ }
|
|
|
+
|
|
|
+ private Prompt buildPrompt(AiWriteGenerateReqVO generateReqVO, AiChatModelDO model, String systemMessage) {
|
|
|
+ // 1. 构建 message 列表
|
|
|
+ List<Message> chatMessages = buildMessages(generateReqVO, systemMessage);
|
|
|
+ // 2. 构建 options 对象
|
|
|
+ AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
|
|
|
+ ChatOptions options = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
|
|
|
+ return new Prompt(chatMessages, options);
|
|
|
+ }
|
|
|
+
|
|
|
+ private List<Message> buildMessages(AiWriteGenerateReqVO generateReqVO, String systemMessage) {
|
|
|
+ List<Message> chatMessages = new ArrayList<>();
|
|
|
+ if (StrUtil.isNotBlank(systemMessage)) {
|
|
|
+ // 1.1 角色设定
|
|
|
+ chatMessages.add(new SystemMessage(systemMessage));
|
|
|
+ }
|
|
|
+ // 1.2 用户输入
|
|
|
+ chatMessages.add(new UserMessage(buildUserMessage(generateReqVO)));
|
|
|
+ return chatMessages;
|
|
|
+ }
|
|
|
+
|
|
|
+ private String buildUserMessage(AiWriteGenerateReqVO generateReqVO) {
|
|
|
String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat());
|
|
|
String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getTone());
|
|
|
String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getLanguage());
|