浏览代码

【新增】AI 知识库:文档向量化 demo

xiaoxin 8 月之前
父节点
当前提交
2a984504d9

+ 16 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/DocService.java

@@ -0,0 +1,16 @@
+package cn.iocoder.yudao.module.ai.service.knowledge;
+
+/**
+ * Service interface for the AI knowledge base.
+ *
+ * @author xiaoxin
+ */
+public interface DocService {
+
+
+    /**
+     * Vectorizes a document.
+     */
+    void embeddingDoc();
+
+}

+ 44 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/DocServiceImpl.java

@@ -0,0 +1,44 @@
+package cn.iocoder.yudao.module.ai.service.knowledge;
+
+import jakarta.annotation.Resource;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.ai.document.Document;
+import org.springframework.ai.reader.TextReader;
+import org.springframework.ai.transformer.splitter.TokenTextSplitter;
+import org.springframework.ai.vectorstore.RedisVectorStore;
+import org.springframework.beans.factory.annotation.Value;
+import org.springframework.stereotype.Service;
+
+import java.util.List;
+
+/**
+ * Implementation of the AI knowledge-base {@link DocService}.
+ *
+ * @author xiaoxin
+ */
+@Service
+@Slf4j
+public class DocServiceImpl implements DocService {
+
+    @Resource
+    RedisVectorStore vectorStore;
+    @Resource
+    TokenTextSplitter tokenTextSplitter;
+
+    // TODO @xin: temporary test fixture, remove later
+    @Value("classpath:/webapp/test/Fel.pdf")
+    private org.springframework.core.io.Resource data;
+
+
+    /**
+     * Reads the injected resource, splits it into segments and adds the segments to the
+     * vector store.
+     */
+    @Override
+    public void embeddingDoc() {
+        // Load the resource as documents.
+        // NOTE(review): the resource is a .pdf but is read with TextReader - confirm this is
+        // intended for the demo.
+        List<Document> loaded = new TextReader(data).get();
+        // Split into segments, then embed and store them.
+        List<Document> pieces = tokenTextSplitter.apply(loaded);
+        vectorStore.add(pieces);
+    }
+}

+ 16 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/pom.xml

@@ -39,6 +39,22 @@
             <artifactId>spring-ai-stability-ai-spring-boot-starter</artifactId>
             <version>${spring-ai.version}</version>
         </dependency>
+        <dependency>
+            <groupId>org.springframework.ai</groupId>
+            <artifactId>spring-ai-transformers-spring-boot-starter</artifactId>
+            <version>${spring-ai.version}</version>
+        </dependency>
+        <dependency>
+            <groupId>org.springframework.ai</groupId>
+            <artifactId>spring-ai-redis-store</artifactId>
+            <version>${spring-ai.version}</version>
+        </dependency>
+        <dependency>
+            <groupId>org.springframework.data</groupId>
+            <artifactId>spring-data-redis</artifactId>
+            <optional>true</optional>
+        </dependency>
+
 
         <dependency>
             <groupId>cn.iocoder.boot</groupId>

+ 59 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreAutoConfiguration.java

@@ -0,0 +1,59 @@
+/*
+ * Copyright 2023 - 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.autoconfigure.vectorstore.redis;
+
+import org.springframework.ai.embedding.EmbeddingModel;
+import org.springframework.ai.vectorstore.RedisVectorStore;
+import org.springframework.ai.vectorstore.RedisVectorStore.RedisVectorStoreConfig;
+import org.springframework.boot.autoconfigure.AutoConfiguration;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
+import org.springframework.boot.autoconfigure.data.redis.RedisAutoConfiguration;
+import org.springframework.boot.context.properties.EnableConfigurationProperties;
+import org.springframework.context.annotation.Bean;
+import org.springframework.data.redis.connection.jedis.JedisConnectionFactory;
+import redis.clients.jedis.JedisPooled;
+
+/**
+ * Auto-configuration that registers a {@link RedisVectorStore} bean on top of the
+ * application's {@link JedisConnectionFactory}.
+ *
+ * TODO @xin: temporarily overridden with the latest spring-ai code, because 1.0.0-M1
+ * conflicts with the Redis auto-configuration.
+ *
+ * @author Christian Tzolov
+ * @author Eddú Meléndez
+ */
+@AutoConfiguration(after = RedisAutoConfiguration.class)
+@ConditionalOnClass({JedisPooled.class, JedisConnectionFactory.class, RedisVectorStore.class, EmbeddingModel.class})
+//@ConditionalOnBean(JedisConnectionFactory.class)
+@EnableConfigurationProperties(RedisVectorStoreProperties.class)
+public class RedisVectorStoreAutoConfiguration {
+
+    /**
+     * Creates the Redis vector store using the index name and key prefix from the
+     * configuration properties.
+     * @param embeddingModel the model used to embed documents and queries
+     * @param properties the vector store properties (index, prefix, schema initialization)
+     * @param jedisConnectionFactory the connection factory whose settings are reused
+     * @return the configured vector store
+     */
+    @Bean
+    @ConditionalOnMissingBean
+    public RedisVectorStore vectorStore(EmbeddingModel embeddingModel, RedisVectorStoreProperties properties,
+                                        JedisConnectionFactory jedisConnectionFactory) {
+
+        var config = RedisVectorStoreConfig.builder()
+                .withIndexName(properties.getIndex())
+                .withPrefix(properties.getPrefix())
+                .build();
+
+        return new RedisVectorStore(config, embeddingModel, jedisPooled(jedisConnectionFactory),
+                properties.isInitializeSchema());
+    }
+
+    /**
+     * Builds a {@link JedisPooled} client that mirrors the connection factory's settings.
+     * Besides host and port, the password, SSL flag, client name and timeout are carried
+     * over, so that a secured Redis instance can be reached as well.
+     * @param jedisConnectionFactory the connection factory to read the settings from
+     * @return a pooled Jedis client for the same Redis endpoint
+     */
+    private JedisPooled jedisPooled(JedisConnectionFactory jedisConnectionFactory) {
+        redis.clients.jedis.JedisClientConfig clientConfig = redis.clients.jedis.DefaultJedisClientConfig.builder()
+                .ssl(jedisConnectionFactory.isUseSsl())
+                .clientName(jedisConnectionFactory.getClientName())
+                .timeoutMillis(jedisConnectionFactory.getTimeout())
+                .password(jedisConnectionFactory.getPassword())
+                .build();
+
+        var endpoint = new redis.clients.jedis.HostAndPort(jedisConnectionFactory.getHostName(),
+                jedisConnectionFactory.getPort());
+        return new JedisPooled(endpoint, clientConfig);
+    }
+
+}

+ 456 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/org/springframework/ai/vectorstore/RedisVectorStore.java

@@ -0,0 +1,456 @@
+/*
+ * Copyright 2023 - 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.vectorstore;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.ai.document.Document;
+import org.springframework.ai.embedding.EmbeddingModel;
+import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
+import org.springframework.beans.factory.InitializingBean;
+import org.springframework.util.Assert;
+import org.springframework.util.CollectionUtils;
+import redis.clients.jedis.JedisPooled;
+import redis.clients.jedis.Pipeline;
+import redis.clients.jedis.json.Path2;
+import redis.clients.jedis.search.*;
+import redis.clients.jedis.search.Schema.FieldType;
+import redis.clients.jedis.search.schemafields.*;
+import redis.clients.jedis.search.schemafields.VectorField.VectorAlgorithm;
+
+import java.text.MessageFormat;
+import java.util.*;
+import java.util.function.Function;
+import java.util.function.Predicate;
+import java.util.stream.Collectors;
+
+/**
+ * The RedisVectorStore is for managing and querying vector data in a Redis database. It
+ * offers functionalities like adding, deleting, and performing similarity searches on
+ * documents.
+ *
+ * The store utilizes RedisJSON and RedisSearch to handle JSON documents and to index and
+ * search vector data. It supports various vector algorithms (e.g., FLAT, HSNW) for
+ * efficient similarity searches. Additionally, it allows for custom metadata fields in
+ * the documents to be stored alongside the vector and content data.
+ *
+ * This class requires a RedisVectorStoreConfig configuration object for initialization,
+ * which includes settings like Redis URI, index name, field names, and vector algorithms.
+ * It also requires an EmbeddingModel to convert documents into embeddings before storing
+ * them.
+ *
+ * @author Julien Ruaux
+ * @author Christian Tzolov
+ * @author Eddú Meléndez
+ * @see VectorStore
+ * @see RedisVectorStoreConfig
+ * @see EmbeddingModel
+ */
+public class RedisVectorStore implements VectorStore, InitializingBean {
+
+    /**
+     * Vector index algorithms selectable through the configuration. {@code HSNW} is
+     * translated to Jedis' {@code VectorAlgorithm.HNSW} when the schema is built; any other
+     * value results in {@code VectorAlgorithm.FLAT}.
+     */
+    public enum Algorithm {
+
+        FLAT, HSNW
+
+    }
+
+    /**
+     * Describes a document metadata field by its name and its search schema field type.
+     *
+     * @param name the metadata field name
+     * @param fieldType the schema type used for the field
+     */
+    public record MetadataField(String name, FieldType fieldType) {
+
+        /** Creates a metadata field of type {@code FieldType.TEXT}. */
+        public static MetadataField text(String name) {
+            return new MetadataField(name, FieldType.TEXT);
+        }
+
+        /** Creates a metadata field of type {@code FieldType.NUMERIC}. */
+        public static MetadataField numeric(String name) {
+            return new MetadataField(name, FieldType.NUMERIC);
+        }
+
+        /** Creates a metadata field of type {@code FieldType.TAG}. */
+        public static MetadataField tag(String name) {
+            return new MetadataField(name, FieldType.TAG);
+        }
+
+    }
+
+    /**
+     * Configuration for the Redis vector store.
+     */
+    public static final class RedisVectorStoreConfig {
+
+        private final String indexName;
+
+        private final String prefix;
+
+        private final String contentFieldName;
+
+        private final String embeddingFieldName;
+
+        private final Algorithm vectorAlgorithm;
+
+        private final List<MetadataField> metadataFields;
+
+        private RedisVectorStoreConfig() {
+            this(builder());
+        }
+
+        private RedisVectorStoreConfig(Builder builder) {
+            this.indexName = builder.indexName;
+            this.prefix = builder.prefix;
+            this.contentFieldName = builder.contentFieldName;
+            this.embeddingFieldName = builder.embeddingFieldName;
+            this.vectorAlgorithm = builder.vectorAlgorithm;
+            this.metadataFields = builder.metadataFields;
+        }
+
+        /**
+         * Start building a new configuration.
+         * @return The entry point for creating a new configuration.
+         */
+        public static Builder builder() {
+
+            return new Builder();
+        }
+
+        /**
+         * {@return the default config}
+         */
+        public static RedisVectorStoreConfig defaultConfig() {
+
+            return builder().build();
+        }
+
+        public static class Builder {
+
+            private String indexName = DEFAULT_INDEX_NAME;
+
+            private String prefix = DEFAULT_PREFIX;
+
+            private String contentFieldName = DEFAULT_CONTENT_FIELD_NAME;
+
+            private String embeddingFieldName = DEFAULT_EMBEDDING_FIELD_NAME;
+
+            private Algorithm vectorAlgorithm = DEFAULT_VECTOR_ALGORITHM;
+
+            private List<MetadataField> metadataFields = new ArrayList<>();
+
+            private Builder() {
+            }
+
+            /**
+             * Configures the Redis index name to use.
+             * @param name the index name to use
+             * @return this builder
+             */
+            public Builder withIndexName(String name) {
+                this.indexName = name;
+                return this;
+            }
+
+            /**
+             * Configures the Redis key prefix to use (default: "embedding:").
+             * @param prefix the prefix to use
+             * @return this builder
+             */
+            public Builder withPrefix(String prefix) {
+                this.prefix = prefix;
+                return this;
+            }
+
+            /**
+             * Configures the Redis content field name to use.
+             * @param name the content field name to use
+             * @return this builder
+             */
+            public Builder withContentFieldName(String name) {
+                this.contentFieldName = name;
+                return this;
+            }
+
+            /**
+             * Configures the Redis embedding field name to use.
+             * @param name the embedding field name to use
+             * @return this builder
+             */
+            public Builder withEmbeddingFieldName(String name) {
+                this.embeddingFieldName = name;
+                return this;
+            }
+
+            /**
+             * Configures the Redis vector algorithm to use.
+             * @param algorithm the vector algorithm to use
+             * @return this builder
+             */
+            public Builder withVectorAlgorithm(Algorithm algorithm) {
+                this.vectorAlgorithm = algorithm;
+                return this;
+            }
+
+            public Builder withMetadataFields(MetadataField... fields) {
+                return withMetadataFields(Arrays.asList(fields));
+            }
+
+            /**
+             * Configures the metadata fields that are added to the index schema and returned
+             * by searches.
+             * <p>
+             * The list is copied, so later changes to the caller's list do not affect the
+             * configuration; {@code null} is treated as "no metadata fields".
+             * @param fields the metadata fields, may be {@code null}
+             * @return this builder
+             */
+            public Builder withMetadataFields(List<MetadataField> fields) {
+                this.metadataFields = (fields != null) ? new ArrayList<>(fields) : new ArrayList<>();
+                return this;
+            }
+
+            /**
+             * {@return the immutable configuration}
+             */
+            public RedisVectorStoreConfig build() {
+
+                return new RedisVectorStoreConfig(this);
+            }
+
+        }
+
+    }
+
+    private final boolean initializeSchema;
+
+    public static final String DEFAULT_INDEX_NAME = "spring-ai-index";
+
+    public static final String DEFAULT_CONTENT_FIELD_NAME = "content";
+
+    public static final String DEFAULT_EMBEDDING_FIELD_NAME = "embedding";
+
+    public static final String DEFAULT_PREFIX = "embedding:";
+
+    public static final Algorithm DEFAULT_VECTOR_ALGORITHM = Algorithm.HSNW;
+
+    private static final String QUERY_FORMAT = "%s=>[KNN %s @%s $%s AS %s]";
+
+    private static final Path2 JSON_SET_PATH = Path2.of("$");
+
+    private static final String JSON_PATH_PREFIX = "$.";
+
+    private static final Logger logger = LoggerFactory.getLogger(RedisVectorStore.class);
+
+    private static final Predicate<Object> RESPONSE_OK = Predicate.isEqual("OK");
+
+    private static final Predicate<Object> RESPONSE_DEL_OK = Predicate.isEqual(1l);
+
+    private static final String VECTOR_TYPE_FLOAT32 = "FLOAT32";
+
+    private static final String EMBEDDING_PARAM_NAME = "BLOB";
+
+    public static final String DISTANCE_FIELD_NAME = "vector_score";
+
+    private static final String DEFAULT_DISTANCE_METRIC = "COSINE";
+
+    private final JedisPooled jedis;
+
+    private final EmbeddingModel embeddingModel;
+
+    private final RedisVectorStoreConfig config;
+
+    private FilterExpressionConverter filterExpressionConverter;
+
+    /**
+     * Creates a vector store that talks to Redis through the given client.
+     * @param config the store configuration, must not be {@code null}
+     * @param embeddingModel the model used to embed documents and queries, must not be
+     * {@code null}
+     * @param jedis the Redis client, must not be {@code null}
+     * @param initializeSchema whether {@link #afterPropertiesSet()} should create the index
+     * when it does not exist yet
+     * @throws IllegalArgumentException if a required argument is {@code null}
+     */
+    public RedisVectorStore(RedisVectorStoreConfig config, EmbeddingModel embeddingModel, JedisPooled jedis,
+                            boolean initializeSchema) {
+
+        Assert.notNull(config, "Config must not be null");
+        Assert.notNull(embeddingModel, "Embedding model must not be null");
+        // Fail fast here instead of with a NullPointerException on the first Redis call.
+        Assert.notNull(jedis, "Jedis client must not be null");
+        this.initializeSchema = initializeSchema;
+
+        this.jedis = jedis;
+        this.embeddingModel = embeddingModel;
+        this.config = config;
+        this.filterExpressionConverter = new RedisFilterExpressionConverter(this.config.metadataFields);
+    }
+
+    /**
+     * {@return the Jedis client this store was constructed with}
+     */
+    public JedisPooled getJedis() {
+        return this.jedis;
+    }
+
+    /**
+     * Embeds every document, stores the embedding on the document, and writes one JSON
+     * value per document (key = prefix + id) through a single pipeline.
+     * @param documents the documents to embed and store
+     * @throws RuntimeException if any pipeline reply differs from {@code "OK"}
+     */
+    @Override
+    public void add(List<Document> documents) {
+        try (Pipeline batch = this.jedis.pipelined()) {
+            for (Document doc : documents) {
+                var vector = this.embeddingModel.embed(doc);
+                doc.setEmbedding(vector);
+
+                // Metadata is copied last, exactly as before, so it may override the two
+                // reserved fields when names collide.
+                Map<String, Object> payload = new HashMap<>();
+                payload.put(this.config.embeddingFieldName, vector);
+                payload.put(this.config.contentFieldName, doc.getContent());
+                payload.putAll(doc.getMetadata());
+
+                String redisKey = key(doc.getId());
+                batch.jsonSetWithEscape(redisKey, JSON_SET_PATH, payload);
+            }
+
+            Optional<Object> failure = batch.syncAndReturnAll()
+                    .stream()
+                    .filter(Predicate.not(RESPONSE_OK))
+                    .findAny();
+            if (failure.isEmpty()) {
+                return;
+            }
+
+            String message = MessageFormat.format("Could not add document: {0}", failure.get());
+            if (logger.isErrorEnabled()) {
+                logger.error(message);
+            }
+            throw new RuntimeException(message);
+        }
+    }
+
+    /**
+     * Builds the Redis key for a document id by prepending the configured prefix.
+     */
+    private String key(String id) {
+        return this.config.prefix + id;
+    }
+
+    @Override
+    public Optional<Boolean> delete(List<String> idList) {
+        try (Pipeline pipeline = this.jedis.pipelined()) {
+            for (String id : idList) {
+                pipeline.jsonDel(key(id));
+            }
+            List<Object> responses = pipeline.syncAndReturnAll();
+            Optional<Object> errResponse = responses.stream().filter(Predicate.not(RESPONSE_DEL_OK)).findAny();
+            if (errResponse.isPresent()) {
+                if (logger.isErrorEnabled()) {
+                    logger.error("Could not delete document: {}", errResponse.get());
+                }
+                return Optional.of(false);
+            }
+            return Optional.of(true);
+        }
+    }
+
+    @Override
+    public List<Document> similaritySearch(SearchRequest request) {
+
+        Assert.isTrue(request.getTopK() > 0, "The number of documents to returned must be greater than zero");
+        Assert.isTrue(request.getSimilarityThreshold() >= 0 && request.getSimilarityThreshold() <= 1,
+                "The similarity score is bounded between 0 and 1; least to most similar respectively.");
+
+        String filter = nativeExpressionFilter(request);
+
+        String queryString = String.format(QUERY_FORMAT, filter, request.getTopK(), this.config.embeddingFieldName,
+                EMBEDDING_PARAM_NAME, DISTANCE_FIELD_NAME);
+
+        List<String> returnFields = new ArrayList<>();
+        this.config.metadataFields.stream().map(MetadataField::name).forEach(returnFields::add);
+        returnFields.add(this.config.embeddingFieldName);
+        returnFields.add(this.config.contentFieldName);
+        returnFields.add(DISTANCE_FIELD_NAME);
+        var embedding = toFloatArray(this.embeddingModel.embed(request.getQuery()));
+        Query query = new Query(queryString).addParam(EMBEDDING_PARAM_NAME, RediSearchUtil.toByteArray(embedding))
+                .returnFields(returnFields.toArray(new String[0]))
+                .setSortBy(DISTANCE_FIELD_NAME, true)
+                .dialect(2);
+
+        SearchResult result = this.jedis.ftSearch(this.config.indexName, query);
+        return result.getDocuments()
+                .stream()
+                .filter(d -> similarityScore(d) >= request.getSimilarityThreshold())
+                .map(this::toDocument)
+                .toList();
+    }
+
+    /**
+     * Converts a search hit into a {@link Document}: the key prefix is stripped from the id,
+     * the content and the declared metadata fields present on the hit are copied, and
+     * {@code 1 - similarityScore} is stored under {@link #DISTANCE_FIELD_NAME}.
+     */
+    private Document toDocument(redis.clients.jedis.search.Document doc) {
+        String redisKey = doc.getId();
+        String id = redisKey.substring(this.config.prefix.length());
+
+        String content = null;
+        if (doc.hasProperty(this.config.contentFieldName)) {
+            content = doc.getString(this.config.contentFieldName);
+        }
+
+        Map<String, Object> metadata = this.config.metadataFields.stream()
+                .map(metadataField -> metadataField.name())
+                .filter(fieldName -> doc.hasProperty(fieldName))
+                .collect(Collectors.toMap(fieldName -> fieldName, fieldName -> doc.getString(fieldName)));
+
+        float distance = 1 - similarityScore(doc);
+        metadata.put(DISTANCE_FIELD_NAME, distance);
+        return new Document(id, content, metadata);
+    }
+
+    /**
+     * Turns the distance stored under {@link #DISTANCE_FIELD_NAME} into a score computed as
+     * {@code (2 - distance) / 2}.
+     * NOTE(review): presumably maps a cosine distance in [0, 2] onto [0, 1] - confirm.
+     */
+    private float similarityScore(redis.clients.jedis.search.Document doc) {
+        return (2 - Float.parseFloat(doc.getString(DISTANCE_FIELD_NAME))) / 2;
+    }
+
+    /**
+     * Produces the filter part of the search query: the converted filter expression wrapped
+     * in parentheses, or {@code "*"} when the request carries no filter expression.
+     */
+    private String nativeExpressionFilter(SearchRequest request) {
+        var expression = request.getFilterExpression();
+        if (expression != null) {
+            String converted = this.filterExpressionConverter.convertExpression(expression);
+            return "(" + converted + ")";
+        }
+        return "*";
+    }
+
+    /**
+     * Creates the search index over JSON values with the configured prefix, unless schema
+     * initialization is disabled or an index with the configured name already exists.
+     * @throws RuntimeException if index creation does not answer {@code "OK"}
+     */
+    @Override
+    public void afterPropertiesSet() {
+
+        // Nothing to do when initialization is switched off or the index is already there.
+        if (!this.initializeSchema || this.jedis.ftList().contains(this.config.indexName)) {
+            return;
+        }
+
+        FTCreateParams createParams = FTCreateParams.createParams()
+                .on(IndexDataType.JSON)
+                .addPrefix(this.config.prefix);
+        String reply = this.jedis.ftCreate(this.config.indexName, createParams, schemaFields());
+        if (RESPONSE_OK.test(reply)) {
+            return;
+        }
+        throw new RuntimeException(MessageFormat.format("Could not create index: {0}", reply));
+    }
+
+    /**
+     * Assembles the index schema: a text field for the content, a vector field for the
+     * embedding (dimension taken from the embedding model), and one field per configured
+     * metadata field.
+     */
+    private Iterable<SchemaField> schemaFields() {
+        Map<String, Object> vectorOptions = new HashMap<>();
+        vectorOptions.put("DIM", this.embeddingModel.dimensions());
+        vectorOptions.put("DISTANCE_METRIC", DEFAULT_DISTANCE_METRIC);
+        vectorOptions.put("TYPE", VECTOR_TYPE_FLOAT32);
+
+        SchemaField contentField = TextField.of(jsonPath(this.config.contentFieldName))
+                .as(this.config.contentFieldName)
+                .weight(1.0);
+        SchemaField embeddingField = VectorField.builder()
+                .fieldName(jsonPath(this.config.embeddingFieldName))
+                .algorithm(vectorAlgorithm())
+                .attributes(vectorOptions)
+                .as(this.config.embeddingFieldName)
+                .build();
+
+        List<SchemaField> schema = new ArrayList<>();
+        schema.add(contentField);
+        schema.add(embeddingField);
+        if (CollectionUtils.isEmpty(this.config.metadataFields)) {
+            return schema;
+        }
+        for (MetadataField metadataField : this.config.metadataFields) {
+            schema.add(schemaField(metadataField));
+        }
+        return schema;
+    }
+
+    /**
+     * Maps a metadata field onto the schema field of the matching type, addressed by its
+     * JSON path and aliased to the plain field name.
+     * @throws IllegalArgumentException if the field type is not NUMERIC, TAG or TEXT
+     */
+    private SchemaField schemaField(MetadataField field) {
+        String path = jsonPath(field.name());
+        return switch (field.fieldType()) {
+            case NUMERIC -> NumericField.of(path).as(field.name());
+            case TAG -> TagField.of(path).as(field.name());
+            case TEXT -> TextField.of(path).as(field.name());
+            default -> throw new IllegalArgumentException(
+                    MessageFormat.format("Field {0} has unsupported type {1}", field.name(), field.fieldType()));
+        };
+    }
+
+    /**
+     * Translates the configured algorithm into the Jedis enum: {@code HSNW} becomes
+     * {@code VectorAlgorithm.HNSW}, everything else {@code VectorAlgorithm.FLAT}.
+     */
+    private VectorAlgorithm vectorAlgorithm() {
+        boolean useHnsw = this.config.vectorAlgorithm == Algorithm.HSNW;
+        return useHnsw ? VectorAlgorithm.HNSW : VectorAlgorithm.FLAT;
+    }
+
+    /**
+     * Returns the JSON path of a top-level field, i.e. the field name prefixed with
+     * {@code "$."}.
+     */
+    private String jsonPath(String field) {
+        return JSON_PATH_PREFIX + field;
+    }
+
+    /**
+     * Copies a list of doubles into a {@code float[]} of the same length, narrowing each
+     * value with {@link Double#floatValue()}.
+     */
+    private static float[] toFloatArray(List<Double> embeddingDouble) {
+        float[] narrowed = new float[embeddingDouble.size()];
+        for (int idx = 0; idx < narrowed.length; idx++) {
+            narrowed[idx] = embeddingDouble.get(idx).floatValue();
+        }
+        return narrowed;
+    }
+
+}

二进制
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/resources/webapp/test/Fel.pdf


+ 4 - 0
yudao-server/src/main/resources/application.yaml

@@ -153,6 +153,10 @@ spring:
 
 spring:
   ai:
+    vectorstore:
+      redis:
+        index: default-index
+        prefix: "default:"
     qianfan: # 文心一言
       api-key: x0cuLZ7XsaTCU08vuJWO87Lg
       secret-key: R9mYF9dl9KASgi5RUq0FQt3wRisSnOcK