|
@@ -9,7 +9,6 @@ import cn.iocoder.yudao.framework.ai.chat.messages.MessageType;
|
|
|
import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
|
|
|
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
|
|
|
import cn.iocoder.yudao.module.ai.config.AiChatClientFactory;
|
|
|
-import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
|
|
|
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO;
|
|
|
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
|
|
|
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
|
|
@@ -25,13 +24,12 @@ import cn.iocoder.yudao.module.ai.service.AiChatRoleService;
|
|
|
import cn.iocoder.yudao.module.ai.service.AiChatService;
|
|
|
import lombok.AllArgsConstructor;
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
-import org.springframework.http.MediaType;
|
|
|
import org.springframework.stereotype.Service;
|
|
|
import org.springframework.transaction.annotation.Transactional;
|
|
|
import reactor.core.publisher.Flux;
|
|
|
|
|
|
-import java.io.IOException;
|
|
|
import java.util.List;
|
|
|
+import java.util.concurrent.atomic.AtomicInteger;
|
|
|
import java.util.function.Consumer;
|
|
|
|
|
|
/**
|
|
@@ -76,6 +74,7 @@ public class AiChatServiceImpl implements AiChatService {
|
|
|
chatModal.getModel(), chatModal.getId(), req.getContent(),
|
|
|
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
|
|
String content = null;
|
|
|
+ int tokens = 0;
|
|
|
try {
|
|
|
// 创建 chat 需要的 Prompt
|
|
|
Prompt prompt = new Prompt(req.getContent());
|
|
@@ -87,6 +86,7 @@ public class AiChatServiceImpl implements AiChatService {
|
|
|
ChatClient chatClient = aiChatClientFactory.getChatClient(platformEnum);
|
|
|
ChatResponse call = chatClient.call(prompt);
|
|
|
content = call.getResult().getOutput().getContent();
|
|
|
+ tokens = call.getResults().size();
|
|
|
// 更新 conversation
|
|
|
} catch (Exception e) {
|
|
|
content = ExceptionUtil.getMessage(e);
|
|
@@ -94,7 +94,7 @@ public class AiChatServiceImpl implements AiChatService {
|
|
|
// 保存 chat message
|
|
|
insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
|
|
|
chatModal.getModel(), chatModal.getId(), content,
|
|
|
- null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
|
|
+ tokens, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
|
|
}
|
|
|
return new AiChatMessageRespVO().setContent(content);
|
|
|
}
|
|
@@ -123,8 +123,7 @@ public class AiChatServiceImpl implements AiChatService {
|
|
|
return insertChatMessageDO;
|
|
|
}
|
|
|
|
|
|
- @Override
|
|
|
- public void chatStream(AiChatMessageSendReqVO req, Utf8SseEmitter sseEmitter) {
|
|
|
+ public Flux<AiChatMessageRespVO> chatStream(AiChatMessageSendReqVO req) {
|
|
|
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
|
|
|
// 查询对话
|
|
|
AiChatConversationRespVO conversation = chatConversationService.getConversationOfValidate(req.getConversationId());
|
|
@@ -144,47 +143,43 @@ public class AiChatServiceImpl implements AiChatService {
|
|
|
// req.setTopK(req.getTopK());
|
|
|
// req.setTopP(req.getTopP());
|
|
|
// req.setTemperature(req.getTemperature());
|
|
|
- // 保存 chat message
|
|
|
// 保存 chat message
|
|
|
insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
|
|
|
chatModal.getModel(), chatModal.getId(), req.getContent(),
|
|
|
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
|
|
-
|
|
|
// 获取 client 类型
|
|
|
AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(chatModal.getPlatform());
|
|
|
StreamingChatClient streamingChatClient = aiChatClientFactory.getStreamingChatClient(platformEnum);
|
|
|
Flux<ChatResponse> streamResponse = streamingChatClient.stream(prompt);
|
|
|
-
|
|
|
+ // 转换 flex AiChatMessageRespVO
|
|
|
StringBuffer contentBuffer = new StringBuffer();
|
|
|
- streamResponse.subscribe(
|
|
|
- new Consumer<ChatResponse>() {
|
|
|
- @Override
|
|
|
- public void accept(ChatResponse chatResponse) {
|
|
|
- String content = chatResponse.getResults().get(0).getOutput().getContent();
|
|
|
- try {
|
|
|
- contentBuffer.append(content);
|
|
|
- sseEmitter.send(new AiChatMessageRespVO().setContent(content), MediaType.APPLICATION_JSON);
|
|
|
- } catch (IOException e) {
|
|
|
- log.error("发送异常{}", ExceptionUtil.getMessage(e));
|
|
|
- // 如果不是因为关闭而抛出异常,则重新连接
|
|
|
- sseEmitter.completeWithError(e);
|
|
|
- }
|
|
|
- }
|
|
|
- },
|
|
|
- error -> {
|
|
|
- //
|
|
|
- log.error("subscribe错误 {}", ExceptionUtil.getMessage(error));
|
|
|
- },
|
|
|
- () -> {
|
|
|
- log.info("发送完成!");
|
|
|
- sseEmitter.complete();
|
|
|
- // 保存 chat message
|
|
|
- insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
|
|
|
- chatModal.getModel(), chatModal.getId(), contentBuffer.toString(),
|
|
|
- null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
|
|
-
|
|
|
+ AtomicInteger tokens = new AtomicInteger(0);
|
|
|
+ return streamResponse.map(res -> {
|
|
|
+ AiChatMessageRespVO aiChatMessageRespVO = new AiChatMessageRespVO();
|
|
|
+ aiChatMessageRespVO.setContent(res.getResult().getOutput().getContent());
|
|
|
+ contentBuffer.append(res.getResult().getOutput().getContent());
|
|
|
+ tokens.incrementAndGet();
|
|
|
+ return aiChatMessageRespVO;
|
|
|
}
|
|
|
- );
|
|
|
+ ).doOnComplete(new Runnable() {
|
|
|
+ @Override
|
|
|
+ public void run() {
|
|
|
+ log.info("发送完成!");
|
|
|
+ // 保存 chat message
|
|
|
+ insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
|
|
|
+ chatModal.getModel(), chatModal.getId(), contentBuffer.toString(),
|
|
|
+ tokens.get(), conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
|
|
+ }
|
|
|
+ }).doOnError(new Consumer<Throwable>() {
|
|
|
+ @Override
|
|
|
+ public void accept(Throwable throwable) {
|
|
|
+ log.error("发送错误 {}!", throwable.getMessage());
|
|
|
+ // 保存 chat message
|
|
|
+ insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
|
|
|
+ chatModal.getModel(), chatModal.getId(), throwable.getMessage(),
|
|
|
+ tokens.get(), conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
|
|
+ }
|
|
|
+ });
|
|
|
}
|
|
|
|
|
|
@Override
|
|
@@ -194,7 +189,7 @@ public class AiChatServiceImpl implements AiChatService {
|
|
|
// 获取对话所有 message
|
|
|
List<AiChatMessageDO> aiChatMessageDOList = aiChatMessageMapper.selectByConversationId(conversationId);
|
|
|
// 转换 AiChatMessageRespVO
|
|
|
- return AiChatMessageConvert.INSTANCE.convertAiChatMessageRespVOList(aiChatMessageDOList);
|
|
|
+ return AiChatMessageConvert.INSTANCE.convertAiChatMessageRespVOList(aiChatMessageDOList);
|
|
|
}
|
|
|
|
|
|
@Override
|