|
@@ -23,12 +23,11 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
|
|
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatConversationMapper;
|
|
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatConversationMapper;
|
|
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatMessageMapper;
|
|
import cn.iocoder.yudao.module.ai.dal.mysql.AiChatMessageMapper;
|
|
import cn.iocoder.yudao.module.ai.service.AiChatConversationService;
|
|
import cn.iocoder.yudao.module.ai.service.AiChatConversationService;
|
|
-import cn.iocoder.yudao.module.ai.service.AiChatRoleService;
|
|
|
|
|
|
+import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
|
|
import cn.iocoder.yudao.module.ai.service.AiChatService;
|
|
import cn.iocoder.yudao.module.ai.service.AiChatService;
|
|
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
|
|
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
|
|
import lombok.AllArgsConstructor;
|
|
import lombok.AllArgsConstructor;
|
|
import lombok.extern.slf4j.Slf4j;
|
|
import lombok.extern.slf4j.Slf4j;
|
|
-import org.springframework.boot.autoconfigure.http.HttpMessageConverters;
|
|
|
|
import org.springframework.stereotype.Service;
|
|
import org.springframework.stereotype.Service;
|
|
import org.springframework.transaction.annotation.Transactional;
|
|
import org.springframework.transaction.annotation.Transactional;
|
|
import reactor.core.publisher.Flux;
|
|
import reactor.core.publisher.Flux;
|
|
@@ -53,12 +52,12 @@ import java.util.stream.Collectors;
|
|
public class AiChatServiceImpl implements AiChatService {
|
|
public class AiChatServiceImpl implements AiChatService {
|
|
|
|
|
|
private final AiChatClientFactory aiChatClientFactory;
|
|
private final AiChatClientFactory aiChatClientFactory;
|
|
|
|
+
|
|
private final AiChatMessageMapper aiChatMessageMapper;
|
|
private final AiChatMessageMapper aiChatMessageMapper;
|
|
private final AiChatConversationMapper aiChatConversationMapper;
|
|
private final AiChatConversationMapper aiChatConversationMapper;
|
|
private final AiChatConversationService chatConversationService;
|
|
private final AiChatConversationService chatConversationService;
|
|
private final AiChatModelService aiChatModalService;
|
|
private final AiChatModelService aiChatModalService;
|
|
- private final AiChatRoleService aiChatRoleService;
|
|
|
|
- private final HttpMessageConverters messageConverters;
|
|
|
|
|
|
+ private final AiChatRoleService chatRoleService;
|
|
|
|
|
|
@Transactional(rollbackFor = Exception.class)
|
|
@Transactional(rollbackFor = Exception.class)
|
|
public AiChatMessageRespVO chat(AiChatMessageSendReqVO req) {
|
|
public AiChatMessageRespVO chat(AiChatMessageSendReqVO req) {
|
|
@@ -68,12 +67,7 @@ public class AiChatServiceImpl implements AiChatService {
|
|
// 获取对话模型
|
|
// 获取对话模型
|
|
AiChatModelDO chatModel = aiChatModalService.validateChatModel(conversation.getModelId());
|
|
AiChatModelDO chatModel = aiChatModalService.validateChatModel(conversation.getModelId());
|
|
// 获取角色信息
|
|
// 获取角色信息
|
|
- AiChatRoleDO aiChatRoleDO = null;
|
|
|
|
- if (conversation.getRoleId() != null) {
|
|
|
|
- aiChatRoleDO = aiChatRoleService.validateExists(conversation.getRoleId());
|
|
|
|
- }
|
|
|
|
- // 校验角色是否公开
|
|
|
|
- aiChatRoleService.validateIsPublic(aiChatRoleDO);
|
|
|
|
|
|
+ AiChatRoleDO chatRoleDO = conversation.getRoleId() != null ? chatRoleService.validateChatRole(conversation.getRoleId()) : null;
|
|
// 获取 client 类型
|
|
// 获取 client 类型
|
|
AiPlatformEnum platformEnum = AiPlatformEnum.validatePlatform(chatModel.getPlatform());
|
|
AiPlatformEnum platformEnum = AiPlatformEnum.validatePlatform(chatModel.getPlatform());
|
|
// 保存 chat message
|
|
// 保存 chat message
|
|
@@ -142,12 +136,7 @@ public class AiChatServiceImpl implements AiChatService {
|
|
// 获取对话模型
|
|
// 获取对话模型
|
|
AiChatModelDO chatModel = aiChatModalService.validateChatModel(conversation.getModelId());
|
|
AiChatModelDO chatModel = aiChatModalService.validateChatModel(conversation.getModelId());
|
|
// 获取角色信息
|
|
// 获取角色信息
|
|
- AiChatRoleDO aiChatRoleDO = null;
|
|
|
|
- if (conversation.getRoleId() != null) {
|
|
|
|
- aiChatRoleDO = aiChatRoleService.validateExists(conversation.getRoleId());
|
|
|
|
- }
|
|
|
|
- // 校验角色是否公开
|
|
|
|
- aiChatRoleService.validateIsPublic(aiChatRoleDO);
|
|
|
|
|
|
+ AiChatRoleDO chatRoleDO = conversation.getRoleId() != null ? chatRoleService.validateChatRole(conversation.getRoleId()) : null;
|
|
// 创建 chat 需要的 Prompt
|
|
// 创建 chat 需要的 Prompt
|
|
Prompt prompt = new Prompt(aiChatMessageDO.getContent());
|
|
Prompt prompt = new Prompt(aiChatMessageDO.getContent());
|
|
// 提前创建一个 system message
|
|
// 提前创建一个 system message
|
|
@@ -204,13 +193,6 @@ public class AiChatServiceImpl implements AiChatService {
|
|
AiChatConversationRespVO conversation = chatConversationService.getConversationOfValidate(req.getConversationId());
|
|
AiChatConversationRespVO conversation = chatConversationService.getConversationOfValidate(req.getConversationId());
|
|
// 获取对话模型
|
|
// 获取对话模型
|
|
AiChatModelDO chatModel = aiChatModalService.validateChatModel(conversation.getModelId());
|
|
AiChatModelDO chatModel = aiChatModalService.validateChatModel(conversation.getModelId());
|
|
- // 获取角色信息
|
|
|
|
- AiChatRoleDO aiChatRoleDO = null;
|
|
|
|
- if (conversation.getRoleId() != null) {
|
|
|
|
- aiChatRoleDO = aiChatRoleService.validateExists(conversation.getRoleId());
|
|
|
|
- }
|
|
|
|
- // 校验角色是否公开
|
|
|
|
- aiChatRoleService.validateIsPublic(aiChatRoleDO);
|
|
|
|
AiChatMessageDO userMessage = insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
|
|
AiChatMessageDO userMessage = insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
|
|
chatModel.getModel(), chatModel.getId(), req.getContent(),
|
|
chatModel.getModel(), chatModel.getId(), req.getContent(),
|
|
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
|
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|