Parcourir la source

【代码优化】AI:音乐生成

YunaiV il y a 10 mois
Parent
commit
23baaff84d

+ 2 - 3
yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/music/AiMusicStatusEnum.java

@@ -12,9 +12,8 @@ import lombok.Getter;
 @Getter
 public enum AiMusicStatusEnum {
 
-    // @xin 文档中无失败这个返回值
-    STREAMING(10, "进行中"),
-    COMPLETE(20, "完成");
+    IN_PROGRESS(10, "进行中"),
+    SUCCESS(20, "已完成");
 
     /**
      * 状态

+ 11 - 10
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/music/vo/AiSunoGenerateReqVO.java

@@ -7,11 +7,20 @@ import lombok.Data;
 
 import java.util.List;
 
-@Schema(description = "管理后台 - 音乐生成 Request VO")
+@Schema(description = "管理后台 - AI 音乐生成 Request VO")
 @Data
 public class AiSunoGenerateReqVO {
 
-    @Schema(description = "用于生成音乐音频的提示", requiredMode = Schema.RequiredMode.REQUIRED, example = "创作一首带有轻松吉他旋律的流行歌曲,[verse] 描述夏日海滩的宁静,[chorus] 节奏加快,表达对自由的向往。")
+    @Schema(description = "平台", requiredMode = Schema.RequiredMode.REQUIRED, example = "Suno")
+    @NotBlank(message = "平台不能为空")
+    private String platform; // 参见 AiPlatformEnum 枚举
+
+    @Schema(description = "生成模式", requiredMode = Schema.RequiredMode.REQUIRED, example = "2")
+    @NotNull(message = "生成模式不能为空")
+    private Integer generateMode; // 参见 AiMusicGenerateEnum 枚举
+
+    @Schema(description = "用于生成音乐音频的提示", requiredMode = Schema.RequiredMode.REQUIRED,
+            example = "创作一首带有轻松吉他旋律的流行歌曲,[verse] 描述夏日海滩的宁静,[chorus] 节奏加快,表达对自由的向往。")
     private String prompt;
 
     @Schema(description = "是否纯音乐", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "true")
@@ -26,12 +35,4 @@ public class AiSunoGenerateReqVO {
     @Schema(description = "音乐/歌曲名称", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "夜空中最亮的星")
     private String title;
 
-    @Schema(description = "平台", requiredMode = Schema.RequiredMode.REQUIRED, example = "Suno")
-    @NotBlank(message = "平台不能为空")
-    private String platform; // 参见 AiPlatformEnum 枚举
-
-    @Schema(description = "生成模式", requiredMode = Schema.RequiredMode.REQUIRED, example = "2")
-    @NotNull(message = "生成模式不能为空")
-    private Integer generateMode; // 参见 AiMusicGenerateEnum 枚举
-
 }

+ 19 - 18
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/music/AiMusicDO.java

@@ -1,6 +1,8 @@
 package cn.iocoder.yudao.module.ai.dal.dataobject.music;
 
+import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
 import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
+import cn.iocoder.yudao.module.ai.enums.music.AiMusicGenerateModeEnum;
 import cn.iocoder.yudao.module.ai.enums.music.AiMusicStatusEnum;
 import com.baomidou.mybatisplus.annotation.IdType;
 import com.baomidou.mybatisplus.annotation.TableField;
@@ -38,21 +40,19 @@ public class AiMusicDO extends BaseDO {
      */
     private String title;
 
-    /**
-     * 图片地址
-     */
-    private String imageUrl;
-
     /**
      * 歌词
      */
     private String lyric;
 
+    /**
+     * 图片地址
+     */
+    private String imageUrl;
     /**
      * 音频地址
      */
     private String audioUrl;
-
     /**
      * 视频地址
      */
@@ -65,6 +65,13 @@ public class AiMusicDO extends BaseDO {
      */
     private Integer status;
 
+    /**
+     * 生成模式
+     *
+     * 枚举 {@link AiMusicGenerateModeEnum}
+     */
+    private Integer generateMode;
+
     /**
      * 描述词
      */
@@ -74,28 +81,17 @@ public class AiMusicDO extends BaseDO {
      */
     private String prompt;
 
-    /**
-     * 生成模式
-     */
-    private Integer generateMode;
-
     /**
      * 平台
      * <p>
-     * 枚举 {@link cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum}
+     * 枚举 {@link AiPlatformEnum}
      */
     private String platform;
-
     /**
      * 模型
      */
     private String model;
 
-    /**
-     * 错误信息
-     */
-    private String errorMessage;
-
     /**
      * 音乐风格标签
      */
@@ -107,4 +103,9 @@ public class AiMusicDO extends BaseDO {
      */
     private String taskId;
 
+    /**
+     * 错误信息
+     */
+    private String errorMessage;
+
 }

+ 1 - 3
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/music/AiMusicMapper.java

@@ -1,7 +1,6 @@
 package cn.iocoder.yudao.module.ai.dal.mysql.music;
 
 import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
-import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
 import cn.iocoder.yudao.module.ai.dal.dataobject.music.AiMusicDO;
 import org.apache.ibatis.annotations.Mapper;
 
@@ -16,8 +15,7 @@ import java.util.List;
 public interface AiMusicMapper extends BaseMapperX<AiMusicDO> {
 
     default List<AiMusicDO> selectListByStatus(Integer status) {
-        return selectList(new LambdaQueryWrapperX<AiMusicDO>()
-                .eq(AiMusicDO::getStatus, status));
+        return selectList(AiMusicDO::getStatus, status);
     }
 
 }

+ 3 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/music/AiMusicService.java

@@ -14,10 +14,11 @@ public interface AiMusicService {
     /**
      * 音乐生成
      *
+     * @param userId 用户编号
      * @param reqVO 请求参数
      * @return 生成的音乐ID
      */
-    List<Long> generateMusic(AiSunoGenerateReqVO reqVO);
+    List<Long> generateMusic(Long userId, AiSunoGenerateReqVO reqVO);
 
     /**
      * 同步音乐任务
@@ -25,4 +26,5 @@ public interface AiMusicService {
      * @return 同步数量
      */
     Integer syncMusic();
+
 }

+ 26 - 27
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/music/AiMusicServiceImpl.java

@@ -4,7 +4,6 @@ import cn.hutool.core.collection.CollUtil;
 import cn.hutool.core.text.StrPool;
 import cn.hutool.core.util.StrUtil;
 import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
-import cn.iocoder.yudao.framework.common.util.collection.CollectionUtils;
 import cn.iocoder.yudao.module.ai.controller.admin.music.vo.AiSunoGenerateReqVO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.music.AiMusicDO;
 import cn.iocoder.yudao.module.ai.dal.mysql.music.AiMusicMapper;
@@ -16,7 +15,8 @@ import org.springframework.stereotype.Service;
 
 import java.util.*;
 
-import static cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils.getLoginUserId;
+import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
+import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertMap;
 
 /**
  * AI 音乐 Service 实现类
@@ -34,54 +34,53 @@ public class AiMusicServiceImpl implements AiMusicService {
     private AiMusicMapper musicMapper;
 
     @Override
-    public List<Long> generateMusic(AiSunoGenerateReqVO reqVO) {
+    public List<Long> generateMusic(Long userId, AiSunoGenerateReqVO reqVO) {
+        // 1. 调用 Suno 生成音乐
         List<SunoApi.MusicData> musicDataList;
         if (Objects.equals(AiMusicGenerateModeEnum.LYRIC.getMode(), reqVO.getGenerateMode())) {
             // 1.1 歌词模式
-            SunoApi.MusicGenerateRequest sunoReq = new SunoApi.MusicGenerateRequest(
+            SunoApi.MusicGenerateRequest generateRequest = new SunoApi.MusicGenerateRequest(
                     reqVO.getPrompt(), reqVO.getModelVersion(), CollUtil.join(reqVO.getTags(), StrPool.COMMA), reqVO.getTitle());
-            musicDataList = sunoApi.customGenerate(sunoReq);
+            musicDataList = sunoApi.customGenerate(generateRequest);
         } else if (Objects.equals(AiMusicGenerateModeEnum.DESCRIPTION.getMode(), reqVO.getGenerateMode())) {
             // 1.2 描述模式
-            SunoApi.MusicGenerateRequest sunoReq = new SunoApi.MusicGenerateRequest(
+            SunoApi.MusicGenerateRequest generateRequest = new SunoApi.MusicGenerateRequest(
                     reqVO.getPrompt(), reqVO.getModelVersion(), reqVO.getMakeInstrumental());
-            musicDataList = sunoApi.generate(sunoReq);
+            musicDataList = sunoApi.generate(generateRequest);
         } else {
             throw new IllegalArgumentException(StrUtil.format("未知生成模式({})", reqVO));
         }
+
         // 2. 插入数据库
         if (CollUtil.isEmpty(musicDataList)) {
-
             return Collections.emptyList();
         }
-        List<AiMusicDO> aiMusicDOList = CollectionUtils.convertList(buildMusicDOList(musicDataList), musicDO ->
-                musicDO.setUserId(getLoginUserId())
-                        .setGenerateMode(reqVO.getGenerateMode())
-                        .setPlatform(reqVO.getPlatform()
-                        ));
-        musicMapper.insertBatch(aiMusicDOList);
-        return CollectionUtils.convertList(aiMusicDOList, AiMusicDO::getId);
+        List<AiMusicDO> musicList = buildMusicDOList(musicDataList);
+        musicList.forEach(music -> music.setUserId(userId).setPlatform(music.getPlatform()).setGenerateMode(reqVO.getGenerateMode()));
+        musicMapper.insertBatch(musicList);
+        return convertList(musicList, AiMusicDO::getId);
     }
 
     @Override
     public Integer syncMusic() {
-        List<AiMusicDO> streamingTask = musicMapper.selectListByStatus(AiMusicStatusEnum.STREAMING.getStatus());
+        List<AiMusicDO> streamingTask = musicMapper.selectListByStatus(AiMusicStatusEnum.IN_PROGRESS.getStatus());
         if (CollUtil.isEmpty(streamingTask)) {
             return 0;
         }
         log.info("[syncMusic][Suno 开始同步, 共 ({}) 个任务]", streamingTask.size());
+
         // GET 请求,为避免参数过长,分批次处理
-        CollUtil.split(streamingTask, 36).forEach(chunk -> {
-            Map<String, Long> taskIdMap = CollectionUtils.convertMap(chunk, AiMusicDO::getTaskId, AiMusicDO::getId);
+        CollUtil.split(streamingTask, 36).forEach(chunkList -> {
+            Map<String, Long> taskIdMap = convertMap(chunkList, AiMusicDO::getTaskId, AiMusicDO::getId);
             List<SunoApi.MusicData> musicTaskList = sunoApi.getMusicList(new ArrayList<>(taskIdMap.keySet()));
             if (CollUtil.isEmpty(musicTaskList)) {
                 log.warn("Suno 任务同步失败, 任务ID: [{}]", taskIdMap.keySet());
                 return;
             }
-            List<AiMusicDO> aiMusicDOS = buildMusicDOList(musicTaskList);
-            //回填id
-            aiMusicDOS.forEach(aiMusicDO -> aiMusicDO.setId(taskIdMap.get(aiMusicDO.getTaskId())));
-            musicMapper.updateBatch(aiMusicDOS);
+            // 更新进度
+            List<AiMusicDO> updateMusicList = buildMusicDOList(musicTaskList);
+            updateMusicList.forEach(music -> music.setId(taskIdMap.get(music.getTaskId())));
+            musicMapper.updateBatch(updateMusicList);
         });
         return streamingTask.size();
     }
@@ -89,16 +88,16 @@ public class AiMusicServiceImpl implements AiMusicService {
     /**
      * 构建 AiMusicDO 集合
      *
-     * @param musicTaskList suno 音乐任务列表
+     * @param musicList suno 音乐任务列表
      * @return AiMusicDO 集合
      */
-    private static List<AiMusicDO> buildMusicDOList(List<SunoApi.MusicData> musicTaskList) {
-        return CollectionUtils.convertList(musicTaskList, musicData -> new AiMusicDO()
-                .setTaskId(musicData.id())
+    private static List<AiMusicDO> buildMusicDOList(List<SunoApi.MusicData> musicList) {
+        return convertList(musicList, musicData -> new AiMusicDO()
+                .setTaskId(musicData.id()).setModel(musicData.modelName())
                 .setPrompt(musicData.prompt()).setGptDescriptionPrompt(musicData.gptDescriptionPrompt())
                 .setAudioUrl(musicData.audioUrl()).setVideoUrl(musicData.videoUrl()).setImageUrl(musicData.imageUrl())
                 .setTitle(musicData.title()).setLyric(musicData.lyric()).setTags(StrUtil.split(musicData.tags(), StrPool.COMMA))
-                .setModel(musicData.modelName()).setStatus(Objects.equals("complete", musicData.status()) ? AiMusicStatusEnum.COMPLETE.getStatus() : AiMusicStatusEnum.STREAMING.getStatus()));
+                .setStatus(Objects.equals("complete", musicData.status()) ? AiMusicStatusEnum.SUCCESS.getStatus() : AiMusicStatusEnum.IN_PROGRESS.getStatus()));
 
     }
 }