Kaynağa Gözat

完善 OAuth2ApproveServiceImplTest 单元测试

YunaiV 2 yıl önce
ebeveyn
işleme
65d2dffe1a

+ 9 - 6
yudao-framework/yudao-spring-boot-starter-test/src/main/java/cn/iocoder/yudao/framework/test/core/util/RandomUtils.java

@@ -2,13 +2,16 @@ package cn.iocoder.yudao.framework.test.core.util;
 
 import cn.hutool.core.util.ArrayUtil;
 import cn.hutool.core.util.RandomUtil;
+import cn.hutool.core.util.StrUtil;
 import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
-import cn.iocoder.yudao.framework.common.util.collection.SetUtils;
 import uk.co.jemos.podam.api.PodamFactory;
 import uk.co.jemos.podam.api.PodamFactoryImpl;
 
 import java.lang.reflect.Type;
-import java.util.*;
+import java.util.Arrays;
+import java.util.Date;
+import java.util.List;
+import java.util.Set;
 import java.util.function.Consumer;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
@@ -22,7 +25,6 @@ public class RandomUtils {
 
     private static final int RANDOM_STRING_LENGTH = 10;
 
-    private static final Set<String> TINYINT_FIELDS = SetUtils.asSet("type", "category");
     private static final int TINYINT_MAX = 127;
 
     private static final int RANDOM_DATE_MAX = 30;
@@ -41,9 +43,10 @@ public class RandomUtils {
             if (attributeMetadata.getAttributeName().equals("status")) {
                 return RandomUtil.randomEle(CommonStatusEnum.values()).getStatus();
             }
-            // 针对部分字段,使用 tinyint 范围
-            if (TINYINT_FIELDS.contains(attributeMetadata.getAttributeName())) {
-                return RandomUtil.randomInt(1, TINYINT_MAX + 1);
+            // 如果是 type、status 结尾的字段,返回 tinyint 范围
+            if (StrUtil.endWithAnyIgnoreCase(attributeMetadata.getAttributeName(),
+                    "type", "status", "category")) {
+                return RandomUtil.randomInt(0, TINYINT_MAX + 1);
             }
             return RandomUtil.randomInt();
         });

+ 9 - 5
yudao-module-system/yudao-module-system-biz/src/main/java/cn/iocoder/yudao/module/system/service/oauth2/OAuth2ApproveServiceImpl.java

@@ -6,7 +6,9 @@ import cn.iocoder.yudao.framework.common.util.date.DateUtils;
 import cn.iocoder.yudao.module.system.dal.dataobject.oauth2.OAuth2ApproveDO;
 import cn.iocoder.yudao.module.system.dal.dataobject.oauth2.OAuth2ClientDO;
 import cn.iocoder.yudao.module.system.dal.mysql.oauth2.OAuth2ApproveMapper;
+import com.google.common.annotations.VisibleForTesting;
 import org.springframework.stereotype.Service;
+import org.springframework.transaction.annotation.Transactional;
 import org.springframework.validation.annotation.Validated;
 
 import javax.annotation.Resource;
@@ -35,6 +37,7 @@ public class OAuth2ApproveServiceImpl implements OAuth2ApproveService {
     private OAuth2ApproveMapper oauth2ApproveMapper;
 
     @Override
+    @Transactional
     public boolean checkForPreApproval(Long userId, Integer userType, String clientId, Collection<String> requestedScopes) {
         // 第一步,基于 Client 的自动授权计算,如果 scopes 都在自动授权中,则返回 true 通过
         OAuth2ClientDO clientDO = oauth2ClientService.validOAuthClientFromCache(clientId);
@@ -49,14 +52,14 @@ public class OAuth2ApproveServiceImpl implements OAuth2ApproveService {
         }
 
         // 第二步,算上用户已经批准的授权。如果 scopes 都包含,则返回 true
-        List<OAuth2ApproveDO> approveDOs = oauth2ApproveMapper.selectListByUserIdAndUserTypeAndClientId(
-                userId, userType, clientId);
+        List<OAuth2ApproveDO> approveDOs = getApproveList(userId, userType, clientId);
         Set<String> scopes = convertSet(approveDOs, OAuth2ApproveDO::getScope,
-                o -> o.getApproved() && !DateUtils.isExpired(o.getExpiresTime())); // 只保留未过期
+                OAuth2ApproveDO::getApproved); // 只保留未过期的 + 同意
         return CollUtil.containsAll(scopes, requestedScopes);
     }
 
     @Override
+    @Transactional
     public boolean updateAfterApproval(Long userId, Integer userType, String clientId, Map<String, Boolean> requestedScopes) {
         // 如果 requestedScopes 为空,说明没有要求,则返回 true 通过
         if (CollUtil.isEmpty(requestedScopes)) {
@@ -83,8 +86,9 @@ public class OAuth2ApproveServiceImpl implements OAuth2ApproveService {
         return approveDOs;
     }
 
-    private void saveApprove(Long userId, Integer userType, String clientId,
-                             String scope, Boolean approved, Date expireTime) {
+    @VisibleForTesting
+    void saveApprove(Long userId, Integer userType, String clientId,
+                     String scope, Boolean approved, Date expireTime) {
         // 先更新
         OAuth2ApproveDO approveDO = new OAuth2ApproveDO().setUserId(userId).setUserType(userType)
                 .setClientId(clientId).setScope(scope).setApproved(approved).setExpiresTime(expireTime);

+ 267 - 0
yudao-module-system/yudao-module-system-biz/src/test/java/cn/iocoder/yudao/module/system/service/oauth2/OAuth2ApproveServiceImplTest.java

@@ -0,0 +1,267 @@
+package cn.iocoder.yudao.module.system.service.oauth2;
+
+import cn.hutool.core.util.ObjectUtil;
+import cn.iocoder.yudao.framework.common.enums.UserTypeEnum;
+import cn.iocoder.yudao.framework.common.util.date.DateUtils;
+import cn.iocoder.yudao.framework.test.core.ut.BaseDbUnitTest;
+import cn.iocoder.yudao.module.system.dal.dataobject.oauth2.OAuth2ApproveDO;
+import cn.iocoder.yudao.module.system.dal.dataobject.oauth2.OAuth2ClientDO;
+import cn.iocoder.yudao.module.system.dal.mysql.oauth2.OAuth2ApproveMapper;
+import org.assertj.core.util.Lists;
+import org.junit.jupiter.api.Test;
+import org.springframework.boot.test.mock.mockito.MockBean;
+import org.springframework.context.annotation.Import;
+
+import javax.annotation.Resource;
+import java.time.Duration;
+import java.util.*;
+
+import static cn.hutool.core.util.RandomUtil.*;
+import static cn.iocoder.yudao.framework.common.util.date.DateUtils.addTime;
+import static cn.iocoder.yudao.framework.test.core.util.AssertUtils.assertPojoEquals;
+import static cn.iocoder.yudao.framework.test.core.util.RandomUtils.randomString;
+import static cn.iocoder.yudao.framework.test.core.util.RandomUtils.*;
+import static org.junit.jupiter.api.Assertions.*;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.when;
+
+/**
+ * {@link OAuth2ApproveServiceImpl} 的单元测试类
+ *
+ * @author 芋道源码
+ */
+@Import(OAuth2ApproveServiceImpl.class)
+public class OAuth2ApproveServiceImplTest extends BaseDbUnitTest {
+
+    @Resource
+    private OAuth2ApproveServiceImpl oauth2ApproveService;
+
+    @Resource
+    private OAuth2ApproveMapper oauth2ApproveMapper;
+
+    @MockBean
+    private OAuth2ClientService oauth2ClientService;
+
+    @Test
+    public void checkForPreApproval_clientAutoApprove() {
+        // 准备参数
+        Long userId = randomLongId();
+        Integer userType = randomEle(UserTypeEnum.values()).getValue();
+        String clientId = randomString();
+        List<String> requestedScopes = Lists.newArrayList("read");
+        // mock 方法
+        when(oauth2ClientService.validOAuthClientFromCache(eq(clientId)))
+                .thenReturn(randomPojo(OAuth2ClientDO.class).setAutoApproveScopes(requestedScopes));
+
+        // 调用
+        boolean success = oauth2ApproveService.checkForPreApproval(userId, userType,
+                clientId, requestedScopes);
+        // 断言
+        assertTrue(success);
+        List<OAuth2ApproveDO> result = oauth2ApproveMapper.selectList();
+        assertEquals(1, result.size());
+        assertEquals(userId, result.get(0).getUserId());
+        assertEquals(userType, result.get(0).getUserType());
+        assertEquals(clientId, result.get(0).getClientId());
+        assertEquals("read", result.get(0).getScope());
+        assertTrue(result.get(0).getApproved());
+        assertFalse(DateUtils.isExpired(result.get(0).getExpiresTime()));
+    }
+
+    @Test
+    public void checkForPreApproval_approve() {
+        // 准备参数
+        Long userId = randomLongId();
+        Integer userType = randomEle(UserTypeEnum.values()).getValue();
+        String clientId = randomString();
+        List<String> requestedScopes = Lists.newArrayList("read");
+        // mock 方法
+        when(oauth2ClientService.validOAuthClientFromCache(eq(clientId)))
+                .thenReturn(randomPojo(OAuth2ClientDO.class).setAutoApproveScopes(null));
+        // mock 数据
+        OAuth2ApproveDO approve = randomPojo(OAuth2ApproveDO.class).setUserId(userId)
+                .setUserType(userType).setClientId(clientId).setScope("read")
+                .setExpiresTime(addTime(Duration.ofDays(1))).setApproved(true); // 同意
+        oauth2ApproveMapper.insert(approve);
+
+        // 调用
+        boolean success = oauth2ApproveService.checkForPreApproval(userId, userType,
+                clientId, requestedScopes);
+        // 断言
+        assertTrue(success);
+    }
+
+    @Test
+    public void checkForPreApproval_reject() {
+        // 准备参数
+        Long userId = randomLongId();
+        Integer userType = randomEle(UserTypeEnum.values()).getValue();
+        String clientId = randomString();
+        List<String> requestedScopes = Lists.newArrayList("read");
+        // mock 方法
+        when(oauth2ClientService.validOAuthClientFromCache(eq(clientId)))
+                .thenReturn(randomPojo(OAuth2ClientDO.class).setAutoApproveScopes(null));
+        // mock 数据
+        OAuth2ApproveDO approve = randomPojo(OAuth2ApproveDO.class).setUserId(userId)
+                .setUserType(userType).setClientId(clientId).setScope("read")
+                .setExpiresTime(addTime(Duration.ofDays(1))).setApproved(false); // 拒绝
+        oauth2ApproveMapper.insert(approve);
+
+        // 调用
+        boolean success = oauth2ApproveService.checkForPreApproval(userId, userType,
+                clientId, requestedScopes);
+        // 断言
+        assertFalse(success);
+    }
+
+    @Test
+    public void testUpdateAfterApproval_none() {
+        // 准备参数
+        Long userId = randomLongId();
+        Integer userType = randomEle(UserTypeEnum.values()).getValue();
+        String clientId = randomString();
+
+        // 调用
+        boolean success = oauth2ApproveService.updateAfterApproval(userId, userType, clientId,
+                null);
+        // 断言
+        assertTrue(success);
+        List<OAuth2ApproveDO> result = oauth2ApproveMapper.selectList();
+        assertEquals(0, result.size());
+    }
+
+    @Test
+    public void testUpdateAfterApproval_approved() {
+        // 准备参数
+        Long userId = randomLongId();
+        Integer userType = randomEle(UserTypeEnum.values()).getValue();
+        String clientId = randomString();
+        Map<String, Boolean> requestedScopes = new LinkedHashMap<>(); // 有序,方便判断
+        requestedScopes.put("read", true);
+        requestedScopes.put("write", false);
+        // mock 方法
+
+        // 调用
+        boolean success = oauth2ApproveService.updateAfterApproval(userId, userType, clientId,
+                requestedScopes);
+        // 断言
+        assertTrue(success);
+        List<OAuth2ApproveDO> result = oauth2ApproveMapper.selectList();
+        assertEquals(2, result.size());
+        // read
+        assertEquals(userId, result.get(0).getUserId());
+        assertEquals(userType, result.get(0).getUserType());
+        assertEquals(clientId, result.get(0).getClientId());
+        assertEquals("read", result.get(0).getScope());
+        assertTrue(result.get(0).getApproved());
+        assertFalse(DateUtils.isExpired(result.get(0).getExpiresTime()));
+        // write
+        assertEquals(userId, result.get(1).getUserId());
+        assertEquals(userType, result.get(1).getUserType());
+        assertEquals(clientId, result.get(1).getClientId());
+        assertEquals("write", result.get(1).getScope());
+        assertFalse(result.get(1).getApproved());
+        assertFalse(DateUtils.isExpired(result.get(1).getExpiresTime()));
+    }
+
+    @Test
+    public void testUpdateAfterApproval_reject() {
+        // 准备参数
+        Long userId = randomLongId();
+        Integer userType = randomEle(UserTypeEnum.values()).getValue();
+        String clientId = randomString();
+        Map<String, Boolean> requestedScopes = new LinkedHashMap<>();
+        requestedScopes.put("write", false);
+        // mock 方法
+
+        // 调用
+        boolean success = oauth2ApproveService.updateAfterApproval(userId, userType, clientId,
+                requestedScopes);
+        // 断言
+        assertFalse(success);
+        List<OAuth2ApproveDO> result = oauth2ApproveMapper.selectList();
+        assertEquals(1, result.size());
+        // write
+        assertEquals(userId, result.get(0).getUserId());
+        assertEquals(userType, result.get(0).getUserType());
+        assertEquals(clientId, result.get(0).getClientId());
+        assertEquals("write", result.get(0).getScope());
+        assertFalse(result.get(0).getApproved());
+        assertFalse(DateUtils.isExpired(result.get(0).getExpiresTime()));
+    }
+
+    @Test
+    public void testGetApproveList() {
+        // 准备参数
+        Long userId = 10L;
+        Integer userType = UserTypeEnum.ADMIN.getValue();
+        String clientId = randomString();
+        // mock 数据
+        OAuth2ApproveDO approve = randomPojo(OAuth2ApproveDO.class).setUserId(userId)
+                .setUserType(userType).setClientId(clientId).setExpiresTime(addTime(Duration.ofDays(1L)));
+        oauth2ApproveMapper.insert(approve); // 未过期
+        oauth2ApproveMapper.insert(ObjectUtil.clone(approve).setId(null)
+                .setExpiresTime(addTime(Duration.ofDays(-1L)))); // 已过期
+
+        // 调用
+        List<OAuth2ApproveDO> result = oauth2ApproveService.getApproveList(userId, userType, clientId);
+        // 断言
+        assertEquals(1, result.size());
+        assertPojoEquals(approve, result.get(0));
+    }
+
+    @Test
+    public void testSaveApprove_insert() {
+        // 准备参数
+        Long userId = randomLongId();
+        Integer userType = randomEle(UserTypeEnum.values()).getValue();
+        String clientId = randomString();
+        String scope = randomString();
+        Boolean approved = randomBoolean();
+        Date expireTime = randomDay(1, 30);
+        // mock 方法
+
+        // 调用
+        oauth2ApproveService.saveApprove(userId, userType, clientId,
+                scope, approved, expireTime);
+        // 断言
+        List<OAuth2ApproveDO> result = oauth2ApproveMapper.selectList();
+        assertEquals(1, result.size());
+        assertEquals(userId, result.get(0).getUserId());
+        assertEquals(userType, result.get(0).getUserType());
+        assertEquals(clientId, result.get(0).getClientId());
+        assertEquals(scope, result.get(0).getScope());
+        assertEquals(approved, result.get(0).getApproved());
+        assertEquals(expireTime, result.get(0).getExpiresTime());
+    }
+
+    @Test
+    public void testSaveApprove_update() {
+        // mock 数据
+        OAuth2ApproveDO approve = randomPojo(OAuth2ApproveDO.class);
+        oauth2ApproveMapper.insert(approve);
+        // 准备参数
+        Long userId = approve.getUserId();
+        Integer userType = approve.getUserType();
+        String clientId = approve.getClientId();
+        String scope = approve.getScope();
+        Boolean approved = randomBoolean();
+        Date expireTime = randomDay(1, 30);
+        // mock 方法
+
+        // 调用
+        oauth2ApproveService.saveApprove(userId, userType, clientId,
+                scope, approved, expireTime);
+        // 断言
+        List<OAuth2ApproveDO> result = oauth2ApproveMapper.selectList();
+        assertEquals(1, result.size());
+        assertEquals(approve.getId(), result.get(0).getId());
+        assertEquals(userId, result.get(0).getUserId());
+        assertEquals(userType, result.get(0).getUserType());
+        assertEquals(clientId, result.get(0).getClientId());
+        assertEquals(scope, result.get(0).getScope());
+        assertEquals(approved, result.get(0).getApproved());
+        assertEquals(expireTime, result.get(0).getExpiresTime());
+    }
+
+}

+ 1 - 0
yudao-module-system/yudao-module-system-biz/src/test/resources/sql/clean.sql

@@ -21,3 +21,4 @@ DELETE FROM "system_tenant";
 DELETE FROM "system_tenant_package";
 DELETE FROM "system_sensitive_word";
 DELETE FROM "system_oauth2_client";
+DELETE FROM "system_oauth2_approve";

+ 16 - 0
yudao-module-system/yudao-module-system-biz/src/test/resources/sql/create_tables.sql

@@ -495,3 +495,19 @@ CREATE TABLE IF NOT EXISTS "system_oauth2_client" (
   "deleted" bit NOT NULL DEFAULT FALSE,
   PRIMARY KEY ("id")
 ) COMMENT 'OAuth2 客户端表';
+
+CREATE TABLE IF NOT EXISTS "system_oauth2_approve" (
+  "id" bigint NOT NULL GENERATED BY DEFAULT AS IDENTITY,
+  "user_id" bigint NOT NULL,
+  "user_type" tinyint NOT NULL,
+  "client_id" varchar NOT NULL,
+  "scope" varchar NOT NULL,
+  "approved" bit NOT NULL DEFAULT FALSE,
+  "expires_time" datetime NOT NULL,
+  "creator" varchar DEFAULT '',
+  "create_time" datetime NOT NULL DEFAULT CURRENT_TIMESTAMP,
+  "updater" varchar DEFAULT '',
+  "update_time" datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+  "deleted" bit NOT NULL DEFAULT FALSE,
+  PRIMARY KEY ("id")
+) COMMENT 'OAuth2 批准表';