Kaynağa Gözat

feat: 脚本兼容Python 3.9, 抽取更多公共代码; 更新使用文档

dhb52 10 ay önce
ebeveyn
işleme
6d19690bea
2 değiştirilmiş dosya ile 157 ekleme ve 151 silme
  1. 10 2
      sql/tools/README.md
  2. 147 149
      sql/tools/convertor.py

+ 10 - 2
sql/tools/README.md

@@ -50,8 +50,16 @@ TODO 暂未支持
 
 使用方式如下:
 
+安装依赖库
+
+```bash
+pip install simple-ddl-parser
+```
+
+执行如下命令打印生成 postgresql 的脚本内容,其他可选参数有:oracle, sqlserver
+
 ```Bash
-python3 convertor.py
+python3 convertor.py postgres
 ```
 
-然后,TODO
+程序将sql脚本打印到终端,可以重定向到临时文件tmp.sql, 确认无误后可以利用IDEA(专业版)进行格式化。

+ 147 - 149
sql/tools/convertor.py

@@ -6,11 +6,12 @@ Author: dhb52 (https://gitee.com/dhb52)
 pip install simple-ddl-parser
 """
 
+import argparse
 import pathlib
 import re
 import time
 from abc import ABC, abstractmethod
-from typing import Dict, Tuple
+from typing import Dict, Generator, Optional, Tuple, Union
 
 from simple_ddl_parser import DDLParser
 
@@ -60,12 +61,12 @@ class Convertor(ABC):
         self.table_script_list = re.findall(r"CREATE TABLE [^;]*;", self.content)
 
     @abstractmethod
-    def translate_type(self, type: str, size: None | int | Tuple[int]) -> str:
+    def translate_type(self, type: str, size: Optional[Union[int, Tuple[int]]]) -> str:
         """字段类型转换
 
         Args:
             type (str): 字段类型
-            size (None | int | Tuple[int]): 字段长度描述, 如varchar(255), decimal(10,2)
+            size (Optional[Union[int, Tuple[int]]]): 字段长度描述, 如varchar(255), decimal(10,2)
 
         Returns:
             str: 类型定义
@@ -97,7 +98,7 @@ class Convertor(ABC):
         pass
 
     @abstractmethod
-    def gen_index(self, table_ddl: Dict) -> str:
+    def gen_index(self, ddl: Dict) -> str:
         """生成索引定义
 
         Args:
@@ -133,6 +134,55 @@ class Convertor(ABC):
         """
         pass
 
+    @staticmethod
+    def inserts(table_name: str, script_content: str) -> Generator:
+        PREFIX = f"INSERT INTO `{table_name}`"
+
+        # 收集 `table_name` 对应的 insert 语句
+        for line in script_content.split("\n"):
+            if line.startswith(PREFIX):
+                head, tail = line.replace(PREFIX, "").split(" VALUES ", maxsplit=1)
+                head = head.strip().replace("`", "").lower()
+                tail = tail.strip().replace(r"\"", '"')
+                # tail = tail.replace("b'0'", "'0'").replace("b'1'", "'1'")
+                yield f"INSERT INTO {table_name.lower()} {head} VALUES {tail}"
+
+    @staticmethod
+    def index(ddl: Dict) -> Generator:
+        """生成索引定义
+
+        Args:
+            ddl (Dict): 表DDL
+
+        Yields:
+            Generator[str]: create index 语句
+        """
+
+        def generate_columns(columns):
+            keys = [
+                f"{col['name'].lower()}{' ' + col['order'].lower() if col['order'] != 'ASC' else ''}"
+                for col in columns[0]
+            ]
+            return ", ".join(keys)
+
+        for no, index in enumerate(ddl["index"], 1):
+            columns = generate_columns(index["columns"])
+            table_name = ddl["table_name"].lower()
+            yield f"CREATE INDEX idx_{table_name}_{no:02d} ON {table_name} ({columns})"
+
+    @staticmethod
+    def filed_comments(table_sql: str) -> Generator:
+        for line in table_sql.split("\n"):
+            match = re.match(r"^`([^`]+)`.* COMMENT '([^']+)'", line.strip())
+            if match:
+                field = match.group(1)
+                comment_string = match.group(2).replace("\\n", "\n")
+                yield field, comment_string
+
+    def table_comment(self, table_sql: str) -> str:
+        match = re.search(r"COMMENT \= '([^']+)';", table_sql)
+        return match.group(1) if match else None
+
     def print(self):
         """打印转换后的sql脚本到终端"""
         print(
@@ -192,7 +242,7 @@ class PostgreSQLConvertor(Convertor):
     def __init__(self, src):
         super().__init__(src, "PostgreSQL")
 
-    def translate_type(self, type, size):
+    def translate_type(self, type: str, size: Optional[Union[int, Tuple[int]]]):
         """类型转换"""
 
         type = type.lower()
@@ -234,27 +284,30 @@ class PostgreSQLConvertor(Convertor):
 
         table_name = ddl["table_name"].lower()
         columns = [f"{_generate_column(col).strip()}" for col in ddl["columns"]]
+        filed_def_list = ",\n  ".join(columns)
         script = f"""-- ----------------------------
 -- Table structure for {table_name}
 -- ----------------------------
 DROP TABLE IF EXISTS {table_name};
 CREATE TABLE {table_name} (
-    {',\n  '.join(columns)}
+    {filed_def_list}
 );"""
 
         return script
 
-    def gen_comment(self, table_sql, table_name) -> str:
+    def gen_index(self, ddl: Dict) -> str:
+        return "\n".join(f"{script};" for script in self.index(ddl))
+
+    def gen_comment(self, table_sql: str, table_name: str) -> str:
         """生成字段及表的注释"""
 
         script = ""
-        for line in table_sql.split("\n"):
-            match = re.match(r"^`([^`]+)`.* COMMENT '([^']+)'", line.strip())
-            if match:
-                script += f"COMMENT ON COLUMN {table_name}.{match.group(1)} IS '{match.group(2).replace('\\n', '\n')}';\n"
+        for field, comment_string in self.filed_comments(table_sql):
+            script += (
+                f"COMMENT ON COLUMN {table_name}.{field} IS '{comment_string}';" + "\n"
+            )
 
-        match = re.search(r"COMMENT \= '([^']+)';", table_sql)
-        table_comment = match.group(1) if match else None
+        table_comment = self.table_comment(table_sql)
         if table_comment:
             script += f"COMMENT ON TABLE {table_name} IS '{table_comment}';\n"
 
@@ -264,53 +317,21 @@ CREATE TABLE {table_name} (
         """生成主键定义"""
         return f"ALTER TABLE {table_name} ADD CONSTRAINT pk_{table_name} PRIMARY KEY (id);\n"
 
-    def gen_index(self, ddl) -> str:
-        """生成 index"""
-
-        def generate_columns(columns):
-            keys = [
-                f"{col['name'].lower()}{" " + col['order'].lower() if col['order'] != 'ASC' else ''}"
-                for col in columns[0]
-            ]
-            return ", ".join(keys)
-
-        script = ""
-        for no, index in enumerate(ddl["index"], 1):
-            columns = generate_columns(index["columns"])
-            table_name = ddl["table_name"].lower()
-            script += (
-                f"CREATE INDEX idx_{table_name}_{no:02d} ON {table_name} ({columns});\n"
-            )
-
-        return script
-
-    def gen_insert(self, table_name) -> str:
+    def gen_insert(self, table_name: str) -> str:
         """生成 insert 语句,以及根据最后的 insert id+1 生成 Sequence"""
 
-        PREFIX = f"INSERT INTO `{table_name}`"
-
-        # 收集 `table_name` 对应的 insert 语句
-        inserts = []
-        for line in self.content.split("\n"):
-            if line.startswith(PREFIX):
-                head, tail = line.replace(PREFIX, "").split(" VALUES ", maxsplit=1)
-                head = head.strip().replace("`", "").lower()
-                tail = tail.strip().replace(r"\"", '"')
-                script = f"INSERT INTO {table_name.lower()} {head} VALUES {tail}"
-                # bit(1)数据转换
-                script = script.replace("b'0'", "'0'").replace("b'1'", "'1'")
-                inserts.append(script)
-
+        inserts = list(Convertor.inserts(table_name, self.content))
         ## 生成 insert 脚本
         script = ""
         last_id = 0
         if inserts:
+            inserts_lines = "\n".join(inserts)
             script += f"""\n\n-- ----------------------------
 -- Records of {table_name.lower()}
 -- ----------------------------
 -- @formatter:off
 BEGIN;
-{'\n'.join(inserts)}
+{inserts_lines}
 COMMIT;
 -- @formatter:on"""
             match = re.search(r"VALUES \((\d+),", inserts[-1])
@@ -332,7 +353,7 @@ class OracleConvertor(Convertor):
     def __init__(self, src):
         super().__init__(src, "Oracle")
 
-    def translate_type(self, type, size: None | int | Tuple[int]):
+    def translate_type(self, type: str, size: Optional[Union[int, Tuple[int]]]):
         """类型转换"""
         type = type.lower()
 
@@ -369,15 +390,19 @@ class OracleConvertor(Convertor):
             full_type = self.translate_type(type, col["size"])
             nullable = "NULL" if col["nullable"] else "NOT NULL"
             default = f"DEFAULT {col['default']}" if col["default"] is not None else ""
-            return f"{'\"size\"' if name == "size" else name } {full_type} {default} {nullable}"
+            # Oracle 中 size 不能作为字段名
+            field_name = '"size"' if name == "size" else name
+            # Oracle DEFAULT 定义在 NULLABLE 之前
+            return f"{field_name} {full_type} {default} {nullable}"
 
         table_name = ddl["table_name"].lower()
         columns = [f"{generate_column(col).strip()}" for col in ddl["columns"]]
+        field_def_list = ",\n    ".join(columns)
         script = f"""-- ----------------------------
 -- Table structure for {table_name}
 -- ----------------------------
-CREATE TABLE {ddl['table_name'].lower()} (
-    {',\n    '.join(columns)}
+CREATE TABLE {table_name} (
+    {field_def_list}
 );"""
 
         # oracle INSERT '' 不能通过 NOT NULL 校验
@@ -385,72 +410,51 @@ CREATE TABLE {ddl['table_name'].lower()} (
 
         return script
 
-    def gen_comment(self, table_sql, table_name) -> str:
+    def gen_index(self, ddl: Dict) -> str:
+        return "\n".join(f"{script};" for script in self.index(ddl))
+
+    def gen_comment(self, table_sql: str, table_name: str) -> str:
         script = ""
-        for line in table_sql.split("\n"):
-            match = re.search(r"`([^`]+)`.* COMMENT '([^']+)'", line)
-            if match:
-                script += f"COMMENT ON COLUMN {table_name}.{match.group(1)} IS '{match.group(2).replace('\\n', '\n')}';\n"
+        for field, comment_string in self.filed_comments(table_sql):
+            script += (
+                f"COMMENT ON COLUMN {table_name}.{field} IS '{comment_string}';" + "\n"
+            )
 
-        match = re.search(r"COMMENT \= '([^']+)';", table_sql)
-        table_comment = match.group(1) if match else None
+        table_comment = self.table_comment(table_sql)
         if table_comment:
-            script += f"COMMENT ON TABLE {table_name} IS '{table_comment}';"
+            script += f"COMMENT ON TABLE {table_name} IS '{table_comment}';\n"
 
         return script
 
-    def gen_pk(self, table_name) -> str:
+    def gen_pk(self, table_name: str) -> str:
         """生成主键定义"""
         return f"ALTER TABLE {table_name} ADD CONSTRAINT pk_{table_name} PRIMARY KEY (id);\n"
 
-    def gen_index(self, table_ddl) -> str:
-        """生成 INDEX 定义"""
-
-        def generate_columns(columns):
-            keys = [
-                f"{col['name'].lower()}{" " + col['order'].lower() if col['order'] != 'ASC' else ''}"
-                for col in columns[0]
-            ]
-            return ", ".join(keys)
-
-        script = ""
-        for no, index in enumerate(table_ddl["index"], 1):
-            columns = generate_columns(index["columns"])
-            table_name = table_ddl["table_name"].lower()
-            script += (
-                f"CREATE INDEX idx_{table_name}_{no:02d} ON {table_name} ({columns});\n"
-            )
-        return script
+    def gen_index(self, ddl: Dict) -> str:
+        return "\n".join(f"{script};" for script in self.index(ddl))
 
-    def gen_insert(self, table_name) -> str:
+    def gen_insert(self, table_name: str) -> str:
         """拷贝 INSERT 语句"""
-        PREFIX = f"INSERT INTO `{table_name}`"
         inserts = []
-        for line in self.content.split("\n"):
-            if line.startswith(PREFIX):
-                head, tail = line.replace(PREFIX, "").split(" VALUES ", maxsplit=1)
-                head = head.strip().replace("`", "").lower()
-                tail = tail.strip().replace(r"\"", '"')
-                script = f"INSERT INTO {table_name.lower()} {head} VALUES {tail}"
-                # bit(1)数据转换
-                script = script.replace("b'0'", "'0'").replace("b'1'", "'1'")
-                # 对日期数据添加 TO_DATE 转换
-                script = re.sub(
-                    r"('\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}')",
-                    r"to_date(\g<1>, 'SYYYY-MM-DD HH24:MI:SS')",
-                    script,
-                )
-                inserts.append(script)
+        for insert_script in Convertor.inserts(table_name, self.content):
+            # 对日期数据添加 TO_DATE 转换
+            insert_script = re.sub(
+                r"('\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}')",
+                r"to_date(\g<1>, 'SYYYY-MM-DD HH24:MI:SS')",
+                insert_script,
+            )
+            inserts.append(insert_script)
 
         ## 生成 insert 脚本
         script = ""
         last_id = 0
         if inserts:
+            inserts_lines = "\n".join(inserts)
             script += f"""\n\n-- ----------------------------
 -- Records of {table_name.lower()}
 -- ----------------------------
 -- @formatter:off
-{'\n'.join(inserts)}
+{inserts_lines}
 COMMIT;
 -- @formatter:on"""
             match = re.search(r"VALUES \((\d+),", inserts[-1])
@@ -476,7 +480,7 @@ class SQLServerConvertor(Convertor):
     def __init__(self, src):
         super().__init__(src, "Microsoft SQL Server")
 
-    def translate_type(self, type, size):
+    def translate_type(self, type: str, size: Optional[Union[int, Tuple[int]]]):
         """类型转换"""
 
         type = type.lower()
@@ -507,7 +511,7 @@ class SQLServerConvertor(Convertor):
 
         def _generate_column(col):
             name = col["name"].lower()
-            if name == 'id':
+            if name == "id":
                 return "id bigint NOT NULL PRIMARY KEY IDENTITY"
             if name == "deleted":
                 return "deleted bit DEFAULT 0 NOT NULL"
@@ -520,35 +524,34 @@ class SQLServerConvertor(Convertor):
 
         table_name = ddl["table_name"].lower()
         columns = [f"{_generate_column(col).strip()}" for col in ddl["columns"]]
+        filed_def_list = ",\n    ".join(columns)
         script = f"""-- ----------------------------
 -- Table structure for {table_name}
 -- ----------------------------
 DROP TABLE IF EXISTS {table_name};
 CREATE TABLE {table_name} (
-    {',\n    '.join(columns)}
+    {filed_def_list}
 )
 GO"""
 
         return script
 
-    def gen_comment(self, table_sql, table_name) -> str:
+    def gen_comment(self, table_sql: str, table_name: str) -> str:
         """生成字段及表的注释"""
 
         script = ""
-        for line in table_sql.split("\n"):
-            match = re.match(r"^`([^`]+)`.* COMMENT '([^']+)'", line.strip())
-            if match:
-                script += f"""EXEC sp_addextendedproperty
-    'MS_Description', N'{match.group(2).replace('\\n', '\n')}',
+
+        for field, comment_string in self.filed_comments(table_sql):
+            script += f"""EXEC sp_addextendedproperty
+    'MS_Description', N'{comment_string}',
     'SCHEMA', N'dbo',
     'TABLE', N'{table_name}',
-    'COLUMN', N'{match.group(1)}'
+    'COLUMN', N'{field}'
 GO
 
 """
 
-        match = re.search(r"COMMENT \= '([^']+)';", table_sql)
-        table_comment = match.group(1) if match else None
+        table_comment = self.table_comment(table_sql)
         if table_comment:
             script += f"""EXEC sp_addextendedproperty
     'MS_Description', N'{table_comment}',
@@ -557,55 +560,34 @@ GO
 GO
 
 """
-
         return script
 
-    def gen_pk(self, table_name) -> str:
+    def gen_pk(self, table_name: str) -> str:
         """生成主键定义"""
         return ""
 
-    def gen_index(self, ddl) -> str:
+    def gen_index(self, ddl: Dict) -> str:
         """生成 index"""
+        return "\n".join(f"{script}\nGO" for script in self.index(ddl))
 
-        def generate_columns(columns):
-            keys = [
-                f"{col['name'].lower()}{" " + col['order'].lower() if col['order'] != 'ASC' else ''}"
-                for col in columns[0]
-            ]
-            return ", ".join(keys)
-
-        script = ""
-        for no, index in enumerate(ddl["index"], 1):
-            columns = generate_columns(index["columns"])
-            table_name = ddl["table_name"].lower()
-            script += f"CREATE INDEX idx_{table_name}_{no:02d} ON {table_name} ({columns})\nGO\n"
-
-        return script
-
-    def gen_insert(self, table_name) -> str:
+    def gen_insert(self, table_name: str) -> str:
         """生成 insert 语句,以及根据最后的 insert id+1 生成 Sequence"""
 
-        PREFIX = f"INSERT INTO `{table_name}`"
-
         # 收集 `table_name` 对应的 insert 语句
         inserts = []
-        for line in self.content.split("\n"):
-            if line.startswith(PREFIX):
-                head, tail = line.replace(PREFIX, "").split(" VALUES ", maxsplit=1)
-                head = head.strip().replace("`", "").lower()
-                tail = tail.strip().replace(r"\"", '"')
-                # SQLServer: 字符串前加N,hack,是否存在替换字符串内容的风险
-                tail = tail.replace(", '", ", N'").replace("VALUES ('", "VALUES (N')")
-                script = f"INSERT INTO {table_name.lower()} {head} VALUES {tail}"
-                # bit(1)数据转换
-                script = script.replace("b'0'", "'0'").replace("b'1'", "'1'")
-                # 删除 insert 的结尾分号
-                script = re.sub(";$", r"\nGO", script)
-                inserts.append(script)
+        for insert_script in Convertor.inserts(table_name, self.content):
+            # SQLServer: 字符串前加N,hack,是否存在替换字符串内容的风险
+            insert_script = insert_script.replace(", '", ", N'").replace(
+                "VALUES ('", "VALUES (N')"
+            )
+            # 删除 insert 的结尾分号
+            insert_script = re.sub(";$", r"\nGO", insert_script)
+            inserts.append(insert_script)
 
         ## 生成 insert 脚本
         script = ""
         if inserts:
+            inserts_lines = "\n".join(inserts)
             script += f"""\n\n-- ----------------------------
 -- Records of {table_name.lower()}
 -- ----------------------------
@@ -614,7 +596,7 @@ BEGIN TRANSACTION
 GO
 SET IDENTITY_INSERT {table_name.lower()} ON
 GO
-{'\n'.join(inserts)}
+{inserts_lines}
 SET IDENTITY_INSERT {table_name.lower()} OFF
 GO
 COMMIT
@@ -625,10 +607,26 @@ GO
 
 
 def main():
-    sql_file = pathlib.Path('../mysql/ruoyi-vue-pro.sql').resolve().as_posix()
-    # convertor = PostgreSQLConvertor(sql_file)
-    # convertor = OracleConvertor(sql_file)
-    convertor = SQLServerConvertor(sql_file)
+    parser = argparse.ArgumentParser(description="芋道系统数据库转换工具")
+    parser.add_argument(
+        "type",
+        type=str,
+        help="目标数据库类型",
+        choices=["postgres", "oracle", "sqlserver"],
+    )
+    args = parser.parse_args()
+
+    sql_file = pathlib.Path("../mysql/ruoyi-vue-pro.sql").resolve().as_posix()
+    convertor = None
+    if args.type == "postgres":
+        convertor = PostgreSQLConvertor(sql_file)
+    elif args.type == "oracle":
+        convertor = OracleConvertor(sql_file)
+    elif args.type == "sqlserver":
+        convertor = SQLServerConvertor(sql_file)
+    else:
+        raise NotImplementedError(f"不支持目标数据库类型: {args.type}")
+
     convertor.print()