|
@@ -4,35 +4,28 @@ import cn.hutool.core.collection.CollUtil;
|
|
|
import cn.hutool.core.util.ObjUtil;
|
|
|
import cn.hutool.core.util.StrUtil;
|
|
|
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
|
|
-import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
|
|
|
-import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions;
|
|
|
+import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
|
|
|
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
|
|
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
|
|
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
|
|
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessagePageReqVO;
|
|
|
-import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
|
|
|
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
|
|
|
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO;
|
|
|
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
|
|
|
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
|
|
|
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
|
|
-import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
|
|
|
import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper;
|
|
|
import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
|
|
|
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
|
|
|
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
|
|
|
-import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
|
|
|
-import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions;
|
|
|
import jakarta.annotation.Resource;
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
import org.springframework.ai.chat.messages.*;
|
|
|
+import org.springframework.ai.chat.model.ChatModel;
|
|
|
import org.springframework.ai.chat.model.ChatResponse;
|
|
|
import org.springframework.ai.chat.model.StreamingChatModel;
|
|
|
import org.springframework.ai.chat.prompt.ChatOptions;
|
|
|
import org.springframework.ai.chat.prompt.Prompt;
|
|
|
-import org.springframework.ai.ollama.api.OllamaOptions;
|
|
|
-import org.springframework.ai.openai.OpenAiChatOptions;
|
|
|
-import org.springframework.ai.qianfan.QianFanChatOptions;
|
|
|
import org.springframework.stereotype.Service;
|
|
|
import org.springframework.transaction.annotation.Transactional;
|
|
|
import reactor.core.publisher.Flux;
|
|
@@ -64,47 +57,37 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|
|
@Resource
|
|
|
private AiChatModelService chatModalService;
|
|
|
@Resource
|
|
|
- private AiChatRoleService chatRoleService;
|
|
|
- @Resource
|
|
|
private AiApiKeyService apiKeyService;
|
|
|
|
|
|
@Transactional(rollbackFor = Exception.class)
|
|
|
- public AiChatMessageRespVO sendMessage(AiChatMessageSendReqVO req) {
|
|
|
- return null; // TODO 芋艿:一起改
|
|
|
-// Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
|
|
|
-// // 查询对话
|
|
|
-// AiChatConversationDO conversation = chatConversationService.validateExists(req.getConversationId());
|
|
|
-// // 获取对话模型
|
|
|
-// AiChatModelDO chatModel = chatModalService.validateChatModel(conversation.getModelId());
|
|
|
-// // 获取角色信息
|
|
|
-// AiChatRoleDO chatRoleDO = conversation.getRoleId() != null ? chatRoleService.validateChatRole(conversation.getRoleId()) : null;
|
|
|
-// // 获取 client 类型
|
|
|
-// AiPlatformEnum platformEnum = AiPlatformEnum.validatePlatform(chatModel.getPlatform());
|
|
|
-// // 保存 chat message
|
|
|
-// createChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
|
|
|
-// chatModel.getModel(), chatModel.getId(), req.getContent());
|
|
|
-// String content = null;
|
|
|
-// int tokens = 0;
|
|
|
-// try {
|
|
|
-// // 创建 chat 需要的 Prompt
|
|
|
-// Prompt prompt = new Prompt(req.getContent());
|
|
|
-// // TODO @芋艿 @范 看要不要支持这些
|
|
|
-//// req.setTopK(req.getTopK());
|
|
|
-//// req.setTopP(req.getTopP());
|
|
|
-//// req.setTemperature(req.getTemperature());
|
|
|
-// // 发送 call 调用
|
|
|
-// ChatClient chatClient = chatClientFactory.getChatClient(platformEnum);
|
|
|
-// ChatResponse call = chatClient.call(prompt);
|
|
|
-// content = call.getResult().getOutput().getContent();
|
|
|
-// // 更新 conversation
|
|
|
-// } catch (Exception e) {
|
|
|
-// content = ExceptionUtil.getMessage(e);
|
|
|
-// } finally {
|
|
|
-// // 保存 chat message
|
|
|
-// createChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
|
|
|
-// chatModel.getModel(), chatModel.getId(), content);
|
|
|
-// }
|
|
|
-// return new AiChatMessageRespVO().setContent(content);
|
|
|
+ public AiChatMessageSendRespVO sendMessage(AiChatMessageSendReqVO sendReqVO, Long userId) {
|
|
|
+ // 1.1 校验对话存在
|
|
|
+ AiChatConversationDO conversation = chatConversationService.validateChatConversationExists(sendReqVO.getConversationId());
|
|
|
+ if (ObjUtil.notEqual(conversation.getUserId(), userId)) {
|
|
|
+ throw exception(CHAT_CONVERSATION_NOT_EXISTS);
|
|
|
+ }
|
|
|
+ List<AiChatMessageDO> historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId());
|
|
|
+ // 1.2 校验模型
|
|
|
+ AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId());
|
|
|
+ ChatModel chatClient = apiKeyService.getChatClient(model.getKeyId());
|
|
|
+
|
|
|
+ // 2. 插入 user 发送消息
|
|
|
+ AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
|
|
|
+ userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent(), sendReqVO.getUseContext());
|
|
|
+
|
|
|
+ // 3.1 插入 assistant 接收消息
|
|
|
+ AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
|
|
|
+ userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext());
|
|
|
+
|
|
|
+ // 3.2 创建 chat 需要的 Prompt
|
|
|
+ Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO);
|
|
|
+ ChatResponse chatResponse = chatClient.call(prompt);
|
|
|
+
|
|
|
+ // 3.3 段式返回
|
|
|
+ String newContent = chatResponse.getResult().getOutput().getContent();
|
|
|
+ chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(newContent));
|
|
|
+ return new AiChatMessageSendRespVO().setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class))
|
|
|
+ .setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class).setContent(newContent));
|
|
|
}
|
|
|
|
|
|
@Override
|
|
@@ -112,14 +95,12 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|
|
// 1.1 校验对话存在
|
|
|
AiChatConversationDO conversation = chatConversationService.validateChatConversationExists(sendReqVO.getConversationId());
|
|
|
if (ObjUtil.notEqual(conversation.getUserId(), userId)) {
|
|
|
- throw exception(CHAT_CONVERSATION_NOT_EXISTS); // TODO 芋艿:异常情况的对接;
|
|
|
+ throw exception(CHAT_CONVERSATION_NOT_EXISTS);
|
|
|
}
|
|
|
List<AiChatMessageDO> historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId());
|
|
|
// 1.2 校验模型
|
|
|
AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId());
|
|
|
- StreamingChatModel chatClient = apiKeyService.getStreamingChatClient(model.getKeyId());
|
|
|
- // 1.3 获取用户头像、角色头像
|
|
|
- AiChatRoleDO role = conversation.getRoleId() != null ? chatRoleService.getChatRole(conversation.getRoleId()) : null;
|
|
|
+ StreamingChatModel chatClient = apiKeyService.getChatClient(model.getKeyId());
|
|
|
|
|
|
// 2. 插入 user 发送消息
|
|
|
AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
|
|
@@ -149,9 +130,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|
|
// TODO @芋艿:失败的情况下,要不要删除消息
|
|
|
log.error("[sendChatMessageStream][userId({}) sendReqVO({}) 发生异常]", userId, sendReqVO, throwable);
|
|
|
chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(throwable.getMessage()));
|
|
|
- }).onErrorResume(error -> {
|
|
|
- return Flux.just(error(ErrorCodeConstants.CHAT_STREAM_ERROR));
|
|
|
- });
|
|
|
+ }).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.CHAT_STREAM_ERROR)));
|
|
|
}
|
|
|
|
|
|
private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,
|
|
@@ -164,46 +143,17 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|
|
}
|
|
|
// 1.2 history message 历史消息
|
|
|
List<AiChatMessageDO> contextMessages = filterContextMessages(messages, conversation, sendReqVO);
|
|
|
- contextMessages.forEach(message -> {
|
|
|
- // TODO @芋艿:看看有没优化空间
|
|
|
- if (MessageType.USER.getValue().equals(message.getType())) {
|
|
|
- chatMessages.add(new UserMessage(message.getContent()));
|
|
|
- } else {
|
|
|
- chatMessages.add(new AssistantMessage(message.getContent()));
|
|
|
- }
|
|
|
- });
|
|
|
+ contextMessages.forEach(message -> chatMessages.add(AiUtils.buildMessage(message.getType(), message.getContent())));
|
|
|
// 1.3 user message 新发送消息
|
|
|
chatMessages.add(new UserMessage(sendReqVO.getContent()));
|
|
|
|
|
|
// 2. 构建 ChatOptions 对象
|
|
|
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
|
|
|
- ChatOptions chatOptions = buildChatOptions(platform, model.getModel(),
|
|
|
+ ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(),
|
|
|
conversation.getTemperature(), conversation.getMaxTokens());
|
|
|
return new Prompt(chatMessages, chatOptions);
|
|
|
}
|
|
|
|
|
|
- private static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) {
|
|
|
- Float temperatureF = temperature != null ? temperature.floatValue() : null;
|
|
|
- //noinspection EnhancedSwitchMigration
|
|
|
- switch (platform) {
|
|
|
- case OPENAI:
|
|
|
- return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
|
|
|
- case OLLAMA:
|
|
|
- return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens);
|
|
|
- case YI_YAN:
|
|
|
- // TODO 芋艿:貌似 model 只要一设置,就报错
|
|
|
-// return QianFanChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
|
|
|
- return QianFanChatOptions.builder().withTemperature(temperatureF).withMaxTokens(maxTokens).build();
|
|
|
- case XING_HUO:
|
|
|
- return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF)
|
|
|
- .setMaxTokens(maxTokens);
|
|
|
- case QIAN_WEN:
|
|
|
- return TongYiChatOptions.builder().withModel(model).withTemperature(temperature).withMaxTokens(maxTokens).build();
|
|
|
- default:
|
|
|
- throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
/**
|
|
|
* 从历史消息中,获得倒序的 n 组消息作为消息上下文
|
|
|
*
|