convertor.py 19 KB


  1. # encoding=utf8
  2. """芋道系统数据库迁移工具
  3. Author: dhb52 (https://gitee.com/dhb52)
  4. pip install simple-ddl-parser
  5. """
  6. import pathlib
  7. import re
  8. import time
  9. from abc import ABC, abstractmethod
  10. from typing import Dict, Tuple
  11. from simple_ddl_parser import DDLParser
  # Header comment printed at the top of every generated script. The
  # ``{db_type}`` and ``{date}`` placeholders are filled in with ``str.format``.
  PREAMBLE = (
      "/*\n"
      "Yudao Database Transfer Tool\n"
      "Source Server Type : MySQL\n"
      "Target Server Type : {db_type}\n"
      "Date: {date}\n"
      "*/\n"
  )
  def load_and_clean(sql_file: str) -> str:
      """Load the source SQL file and clean it up for the DDL parsing step.

      The clean-up performs plain text substitutions:

      * the ``CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci`` column clause
        is removed;
      * ``KEY`` index definitions become ``INDEX`` while ``UNIQUE`` ones stay
        ``UNIQUE KEY``;
      * ``b'0'`` / ``b'1'`` literals become ``'0'`` / ``'1'``;
      * the ``ENGINE ...`` table options are dropped, keeping a trailing
        ``COMMENT`` if there is one.

      Args:
          sql_file (str): Path of the SQL file.

      Returns:
          str: The cleaned content of the SQL file.
      """
      # Order matters: the " KEY `" pair also rewrites "UNIQUE KEY `" into
      # "UNIQUE INDEX `", which the next pair turns back into "UNIQUE KEY".
      replace_pairs = (
          (" CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci ", " "),
          (" KEY `", " INDEX `"),
          ("UNIQUE INDEX", "UNIQUE KEY"),
          ("b'0'", "'0'"),
          ("b'1'", "'1'"),
      )
      # Close the handle deterministically and do not depend on the locale's
      # default encoding (the dump is presumably UTF-8 -- confirm if it is not).
      with open(sql_file, encoding="utf-8") as sql:
          content = sql.read()
      for old, new in replace_pairs:
          content = content.replace(old, new)
      # "." does not match a newline, so each substitution stays on one line.
      content = re.sub(r"ENGINE.*COMMENT", "COMMENT", content)
      content = re.sub(r"ENGINE.*;", ";", content)
      return content
  class Convertor(ABC):
      """Base class for converting the cleaned MySQL script to another dialect.

      Subclasses supply the dialect-specific pieces (type mapping, CREATE
      TABLE, primary key, indexes, comments and INSERT statements);
      :meth:`print` calls them for every table and writes the result to
      standard output.
      """

      def __init__(self, src: str, db_type) -> None:
          """Load the source script and collect its CREATE TABLE statements.

          Args:
              src (str): Path of the MySQL SQL file.
              db_type: Name of the target database, printed in the header.
          """
          self.src = src
          self.db_type = db_type
          self.content = load_and_clean(self.src)
          # Each statement runs from "CREATE TABLE" up to the first ";".
          self.table_script_list = re.findall(r"CREATE TABLE [^;]*;", self.content)

      @abstractmethod
      def translate_type(self, type: str, size: "None | int | Tuple[int, ...]") -> str:
          """Translate a column type.

          Args:
              type (str): Column type.
              size (None | int | Tuple[int, ...]): Length description of the
                  column, e.g. varchar(255), decimal(10,2).

          Returns:
              str: Type definition for the target database.
          """

      @abstractmethod
      def gen_create(self, table_ddl: Dict) -> str:
          """Generate the CREATE script.

          Args:
              table_ddl (Dict): Parsed DDL of the table.

          Returns:
              str: Generated script.
          """

      @abstractmethod
      def gen_pk(self, table_name: str) -> str:
          """Generate the primary key definition.

          Args:
              table_name (str): Table name.

          Returns:
              str: Generated script.
          """

      @abstractmethod
      def gen_index(self, table_ddl: Dict) -> str:
          """Generate the index definitions.

          Args:
              table_ddl (Dict): Parsed DDL of the table.

          Returns:
              str: Generated script.
          """

      @abstractmethod
      def gen_comment(self, table_sql: str, table_name: str) -> str:
          """Generate the column and table comments.

          Args:
              table_sql (str): Original SQL of the table.
              table_name (str): Table name.

          Returns:
              str: Generated script.
          """

      @abstractmethod
      def gen_insert(self, table_name: str) -> str:
          """Generate the block of INSERT statements.

          Args:
              table_name (str): Table name.

          Returns:
              str: Generated script.
          """

      @staticmethod
      def _join_sections(*sections: str) -> str:
          """Join script sections, one per line, and tidy the blank lines.

          Runs of three or more newlines are collapsed to a single blank line
          and the result always ends with exactly one newline.
          """
          script = "\n".join(sections) + "\n"
          return re.sub("\n{3,}", "\n\n", script).strip() + "\n"

      def print(self):
          """Print the converted SQL script to the terminal."""
          # Inside a method body the bare name ``print`` is still the builtin.
          print(
              PREAMBLE.format(
                  db_type=self.db_type,
                  date=time.strftime("%Y-%m-%d %H:%M:%S"),
              )
          )

          error_scripts = []
          for table_sql in self.table_script_list:
              ddl = DDLParser(table_sql.replace("`", "")).run()

              # A parse failure needs a manual follow-up; remember the script.
              if not ddl:
                  error_scripts.append(table_sql)
                  continue

              table_ddl = ddl[0]
              table_name = table_ddl["table_name"]

              # Ignore the quartz tables.
              if table_name.lower().startswith("qrtz"):
                  continue

              # Generate the five basic parts of every table and combine them.
              print(
                  self._join_sections(
                      self.gen_create(table_ddl),
                      self.gen_pk(table_name),
                      self.gen_index(table_ddl),
                      self.gen_comment(table_sql, table_name),
                      self.gen_insert(table_name),
                  )
              )

          # Print the scripts that failed to parse, unchanged, at the end.
          for script in error_scripts:
              print(script)
  class PostgreSQLConvertor(Convertor):
      """Convertor that generates a PostgreSQL script."""

      # MySQL types whose PostgreSQL counterpart does not depend on the size.
      _FIXED_TYPES = {
          "int": "int4",
          "bigint": "int8",
          "bigint unsigned": "int8",
          "datetime": "timestamp",
          "bit": "bool",
          "tinyint": "int2",
          "smallint": "int2",
          "text": "text",
          "blob": "bytea",
          "mediumblob": "bytea",
      }

      def __init__(self, src):
          super().__init__(src, "PostgreSQL")

      def translate_type(self, type, size):
          """Translate a MySQL column type to the PostgreSQL type.

          Args:
              type (str): MySQL column type, compared case-insensitively.
              size: Length description of the column (``None``, an int or a
                  tuple such as ``(10, 2)``).

          Returns:
              str: PostgreSQL type definition.

          Raises:
              ValueError: If the type has no mapping. Previously ``None`` was
                  returned and ended up as a literal "None" in the script.
          """
          type = type.lower()
          if type == "varchar":
              return f"varchar({size})"
          if type == "decimal":
              # A decimal without precision has no size to unpack.
              if not size:
                  return "numeric"
              digits = size if isinstance(size, (tuple, list)) else (size,)
              return "numeric({})".format(",".join(str(s) for s in digits))
          if type in self._FIXED_TYPES:
              return self._FIXED_TYPES[type]
          raise ValueError(f"unsupported MySQL column type: {type}")

      def gen_create(self, ddl: Dict) -> str:
          """Generate the DROP TABLE / CREATE TABLE script.

          Args:
              ddl (Dict): Parsed DDL of the table; reads ``table_name`` and the
                  ``name``, ``type``, ``size``, ``nullable`` and ``default`` of
                  every entry in ``columns``.

          Returns:
              str: Generated script.
          """

          def _generate_column(col):
              name = col["name"].lower()
              # The "deleted" flag always gets the same fixed definition.
              if name == "deleted":
                  return "deleted int2 NOT NULL DEFAULT 0"

              parts = [
                  name,
                  self.translate_type(col["type"], col["size"]),
                  "NULL" if col["nullable"] else "NOT NULL",
              ]
              if col["default"] is not None:
                  parts.append(f"DEFAULT {col['default']}")
              return " ".join(parts)

          table_name = ddl["table_name"].lower()
          columns = ",\n".join("    " + _generate_column(col) for col in ddl["columns"])
          lines = [
              "-- ----------------------------",
              f"-- Table structure for {table_name}",
              "-- ----------------------------",
              f"DROP TABLE IF EXISTS {table_name};",
              f"CREATE TABLE {table_name} (",
              columns,
              ");",
          ]
          return "\n".join(lines)

      def gen_comment(self, table_sql, table_name) -> str:
          """Generate the comments of the columns and of the table.

          Args:
              table_sql (str): Original SQL of the table.
              table_name (str): Table name.

          Returns:
              str: ``COMMENT ON`` statements, one per line.
          """
          statements = []
          for line in table_sql.split("\n"):
              match = re.match(r"^`([^`]+)`.* COMMENT '([^']+)'", line.strip())
              if match:
                  # A literal backslash-n in the comment becomes a line break.
                  comment = match.group(2).replace("\\n", "\n")
                  statements.append(
                      f"COMMENT ON COLUMN {table_name}.{match.group(1)} IS '{comment}';\n"
                  )

          match = re.search(r"COMMENT \= '([^']+)';", table_sql)
          if match:
              statements.append(f"COMMENT ON TABLE {table_name} IS '{match.group(1)}';\n")
          return "".join(statements)

      def gen_pk(self, table_name) -> str:
          """Generate the primary key definition; the key column is always ``id``."""
          return f"ALTER TABLE {table_name} ADD CONSTRAINT pk_{table_name} PRIMARY KEY (id);\n"

      def gen_index(self, ddl) -> str:
          """Generate one CREATE INDEX statement per entry of ``ddl["index"]``."""

          def generate_columns(columns):
              keys = []
              for col in columns[0]:
                  order = col["order"]
                  # Only a non-default sort order is spelled out.
                  suffix = " " + order.lower() if order and order.upper() != "ASC" else ""
                  keys.append(col["name"].lower() + suffix)
              return ", ".join(keys)

          table_name = ddl["table_name"].lower()
          statements = []
          for no, index in enumerate(ddl["index"], 1):
              columns = generate_columns(index["columns"])
              statements.append(
                  f"CREATE INDEX idx_{table_name}_{no:02d} ON {table_name} ({columns});\n"
              )
          return "".join(statements)

      def gen_insert(self, table_name) -> str:
          """Generate the INSERT statements and the sequence of the table.

          The sequence starts at the largest leading numeric value found in
          the INSERT statements plus one (1 when there is none); the first
          inserted column is assumed to be the id.

          Args:
              table_name (str): Table name as written in the source script.

          Returns:
              str: Generated script.

          Raises:
              ValueError: If an INSERT line of the table has no `` VALUES ``.
          """
          prefix = f"INSERT INTO `{table_name}`"
          target = table_name.lower()

          # Collect the insert statements that belong to `table_name`.
          inserts = []
          for line in self.content.split("\n"):
              if not line.startswith(prefix):
                  continue
              # Cut the prefix off the front only; a replace() would also
              # remove the same text from inside the row data.
              head, sep, tail = line[len(prefix):].partition(" VALUES ")
              if not sep:
                  raise ValueError(f"no ' VALUES ' in insert statement: {line[:80]}")
              head = head.strip().replace("`", "").lower()
              tail = tail.strip().replace(r"\"", '"')
              script = f"INSERT INTO {target} {head} VALUES {tail}"
              # Convert bit(1) data.
              script = script.replace("b'0'", "'0'").replace("b'1'", "'1'")
              inserts.append(script)

          # Build the insert script.
          parts = []
          last_id = 0
          if inserts:
              lines = [
                  "-- ----------------------------",
                  f"-- Records of {target}",
                  "-- ----------------------------",
                  "-- @formatter:off",
                  "BEGIN;",
                  *inserts,
                  "COMMIT;",
                  "-- @formatter:on",
              ]
              parts.append("\n\n" + "\n".join(lines))
              # Use the largest id rather than the last one so the sequence
              # is safe even when the rows are not ordered by id.
              found = (re.search(r"VALUES \((\d+),", script) for script in inserts)
              last_id = max((int(m.group(1)) for m in found if m), default=0)

          # Build the sequence.
          sequence = [
              f"DROP SEQUENCE IF EXISTS {table_name}_seq;",
              f"CREATE SEQUENCE {table_name}_seq",
              f"START {last_id + 1};",
          ]
          parts.append("\n\n" + "\n".join(sequence))
          return "".join(parts)
  class OracleConvertor(Convertor):
      """Convertor that generates an Oracle script."""

      # MySQL types whose Oracle counterpart does not depend on the size.
      _FIXED_TYPES = {
          "int": "number",
          "bigint": "number",
          "bigint unsigned": "number",
          "datetime": "date",
          "bit": "number(1,0)",
          "tinyint": "smallint",
          "smallint": "smallint",
          "text": "clob",
          "blob": "blob",
          "mediumblob": "blob",
      }

      # Column names that have to be written in double quotes. CREATE TABLE
      # already quoted "size"; comments and inserts must refer to it the same
      # way.
      _QUOTED_COLUMNS = frozenset({"size"})

      def __init__(self, src):
          super().__init__(src, "Oracle")

      @classmethod
      def _column_name(cls, name):
          """Return the lower-cased column name, quoted when required."""
          name = name.lower()
          return f'"{name}"' if name in cls._QUOTED_COLUMNS else name

      def translate_type(self, type, size: "None | int | Tuple[int, ...]"):
          """Translate a MySQL column type to the Oracle type.

          Args:
              type (str): MySQL column type, compared case-insensitively.
              size: Length description of the column (``None``, an int or a
                  tuple such as ``(10, 2)``).

          Returns:
              str: Oracle type definition.

          Raises:
              ValueError: If the type has no mapping. Previously ``None`` was
                  returned and ended up as a literal "None" in the script.
          """
          type = type.lower()
          if type == "varchar":
              # The length is capped at 4000.
              return f"varchar2({size if size < 4000 else 4000})"
          if type == "decimal":
              # A decimal without precision has no size to unpack.
              if not size:
                  return "number"
              digits = size if isinstance(size, (tuple, list)) else (size,)
              return "number({})".format(",".join(str(s) for s in digits))
          if type in self._FIXED_TYPES:
              return self._FIXED_TYPES[type]
          raise ValueError(f"unsupported MySQL column type: {type}")

      def gen_create(self, ddl) -> str:
          """Generate the CREATE TABLE statement.

          Args:
              ddl (Dict): Parsed DDL of the table; reads ``table_name`` and the
                  ``name``, ``type``, ``size``, ``nullable`` and ``default`` of
                  every entry in ``columns``.

          Returns:
              str: Generated script.
          """

          def generate_column(col):
              name = col["name"].lower()
              # The "deleted" flag always gets the same fixed definition.
              if name == "deleted":
                  return "deleted number(1,0) DEFAULT 0 NOT NULL"

              parts = [
                  self._column_name(name),
                  self.translate_type(col["type"], col["size"]),
              ]
              if col["default"] is not None:
                  parts.append(f"DEFAULT {col['default']}")
              parts.append("NULL" if col["nullable"] else "NOT NULL")
              return " ".join(parts)

          table_name = ddl["table_name"].lower()
          columns = ",\n".join("    " + generate_column(col) for col in ddl["columns"])
          lines = [
              "-- ----------------------------",
              f"-- Table structure for {table_name}",
              "-- ----------------------------",
              f"CREATE TABLE {table_name} (",
              columns,
              ");",
          ]
          script = "\n".join(lines)
          # An Oracle INSERT of '' does not pass the NOT NULL check.
          return script.replace("DEFAULT '' NOT NULL", "DEFAULT '' NULL")

      def gen_comment(self, table_sql, table_name) -> str:
          """Generate the comments of the columns and of the table.

          Args:
              table_sql (str): Original SQL of the table.
              table_name (str): Table name.

          Returns:
              str: ``COMMENT ON`` statements.
          """
          statements = []
          for line in table_sql.split("\n"):
              match = re.search(r"`([^`]+)`.* COMMENT '([^']+)'", line)
              if match:
                  column = self._column_name(match.group(1))
                  # A literal backslash-n in the comment becomes a line break.
                  comment = match.group(2).replace("\\n", "\n")
                  statements.append(
                      f"COMMENT ON COLUMN {table_name}.{column} IS '{comment}';\n"
                  )

          match = re.search(r"COMMENT \= '([^']+)';", table_sql)
          if match:
              statements.append(f"COMMENT ON TABLE {table_name} IS '{match.group(1)}';")
          return "".join(statements)

      def gen_pk(self, table_name) -> str:
          """Generate the primary key definition; the key column is always ``id``."""
          return f"ALTER TABLE {table_name} ADD CONSTRAINT pk_{table_name} PRIMARY KEY (id);\n"

      def gen_index(self, table_ddl) -> str:
          """Generate one CREATE INDEX statement per entry of ``table_ddl["index"]``."""

          def generate_columns(columns):
              keys = []
              for col in columns[0]:
                  order = col["order"]
                  # Only a non-default sort order is spelled out.
                  suffix = " " + order.lower() if order and order.upper() != "ASC" else ""
                  keys.append(col["name"].lower() + suffix)
              return ", ".join(keys)

          table_name = table_ddl["table_name"].lower()
          statements = []
          for no, index in enumerate(table_ddl["index"], 1):
              columns = generate_columns(index["columns"])
              statements.append(
                  f"CREATE INDEX idx_{table_name}_{no:02d} ON {table_name} ({columns});\n"
              )
          return "".join(statements)

      def gen_insert(self, table_name) -> str:
          """Copy the INSERT statements and generate the sequence of the table.

          The sequence starts at the largest leading numeric value found in
          the INSERT statements plus one (1 when there is none); the first
          inserted column is assumed to be the id.

          Args:
              table_name (str): Table name as written in the source script.

          Returns:
              str: Generated script.

          Raises:
              ValueError: If an INSERT line of the table has no `` VALUES ``.
          """
          prefix = f"INSERT INTO `{table_name}`"
          target = table_name.lower()

          inserts = []
          for line in self.content.split("\n"):
              if not line.startswith(prefix):
                  continue
              # Cut the prefix off the front only; a replace() would also
              # remove the same text from inside the row data.
              head, sep, tail = line[len(prefix):].partition(" VALUES ")
              if not sep:
                  raise ValueError(f"no ' VALUES ' in insert statement: {line[:80]}")
              head = head.strip().replace("`", "").lower()
              # Quote the same column names that CREATE TABLE quotes.
              head = re.sub(r"\w+", lambda m: self._column_name(m.group(0)), head)
              tail = tail.strip().replace(r"\"", '"')
              script = f"INSERT INTO {target} {head} VALUES {tail}"
              # Convert bit(1) data.
              script = script.replace("b'0'", "'0'").replace("b'1'", "'1'")
              # Wrap date values in a TO_DATE conversion.
              script = re.sub(
                  r"('\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}')",
                  r"to_date(\g<1>, 'SYYYY-MM-DD HH24:MI:SS')",
                  script,
              )
              inserts.append(script)

          # Build the insert script.
          parts = []
          last_id = 0
          if inserts:
              lines = [
                  "-- ----------------------------",
                  f"-- Records of {target}",
                  "-- ----------------------------",
                  "-- @formatter:off",
                  *inserts,
                  "COMMIT;",
                  "-- @formatter:on",
              ]
              parts.append("\n\n" + "\n".join(lines))
              # Use the largest id rather than the last one so the sequence
              # is safe even when the rows are not ordered by id.
              found = (re.search(r"VALUES \((\d+),", script) for script in inserts)
              last_id = max((int(m.group(1)) for m in found if m), default=0)

          # Build the sequence.
          sequence = [
              f"CREATE SEQUENCE {table_name}_seq",
              f"START WITH {last_id + 1};",
          ]
          parts.append("\n" + "\n".join(sequence))
          return "".join(parts)
  class SQLServerConvertor(Convertor):
      """Convertor that generates a Microsoft SQL Server script.

      Statements are separated by ``GO`` batch separators and string values of
      the INSERT statements get the ``N`` prefix.
      """

      # MySQL types whose SQL Server counterpart does not depend on the size.
      _FIXED_TYPES = {
          "int": "int",
          "bigint": "bigint",
          "bigint unsigned": "bigint",
          "datetime": "datetime2",
          "bit": "varchar(1)",
          "tinyint": "tinyint",
          # SQL Server's tinyint only holds 0..255, so smallint must not be
          # narrowed to it.
          "smallint": "smallint",
          "text": "nvarchar(max)",
          "blob": "varbinary(max)",
          "mediumblob": "varbinary(max)",
      }

      def __init__(self, src):
          super().__init__(src, "Microsoft SQL Server")

      def translate_type(self, type, size):
          """Translate a MySQL column type to the SQL Server type.

          Args:
              type (str): MySQL column type, compared case-insensitively.
              size: Length description of the column (``None``, an int or a
                  tuple such as ``(10, 2)``).

          Returns:
              str: SQL Server type definition.

          Raises:
              ValueError: If the type has no mapping. Previously ``None`` was
                  returned and ended up as a literal "None" in the script.
          """
          type = type.lower()
          if type == "varchar":
              # The length is capped at 4000.
              return f"nvarchar({size if size < 4000 else 4000})"
          if type == "decimal":
              # A decimal without precision has no size to unpack.
              if not size:
                  return "numeric"
              digits = size if isinstance(size, (tuple, list)) else (size,)
              return "numeric({})".format(",".join(str(s) for s in digits))
          if type in self._FIXED_TYPES:
              return self._FIXED_TYPES[type]
          raise ValueError(f"unsupported MySQL column type: {type}")

      def gen_create(self, ddl: Dict) -> str:
          """Generate the DROP TABLE / CREATE TABLE script.

          Args:
              ddl (Dict): Parsed DDL of the table; reads ``table_name`` and the
                  ``name``, ``type``, ``size``, ``nullable`` and ``default`` of
                  every entry in ``columns``.

          Returns:
              str: Generated script, terminated by ``GO``.
          """

          def _generate_column(col):
              name = col["name"].lower()
              # "id" and "deleted" always get the same fixed definitions; the
              # primary key is declared here, which is why gen_pk is empty.
              if name == "id":
                  return "id bigint NOT NULL PRIMARY KEY IDENTITY"
              if name == "deleted":
                  return "deleted bit DEFAULT 0 NOT NULL"

              parts = [name, self.translate_type(col["type"], col["size"])]
              if col["default"] is not None:
                  parts.append(f"DEFAULT {col['default']}")
              parts.append("NULL" if col["nullable"] else "NOT NULL")
              return " ".join(parts)

          table_name = ddl["table_name"].lower()
          columns = ",\n".join("    " + _generate_column(col) for col in ddl["columns"])
          lines = [
              "-- ----------------------------",
              f"-- Table structure for {table_name}",
              "-- ----------------------------",
              f"DROP TABLE IF EXISTS {table_name};",
              f"CREATE TABLE {table_name} (",
              columns,
              ")",
              "GO",
          ]
          return "\n".join(lines)

      def gen_comment(self, table_sql, table_name) -> str:
          """Generate the comments of the columns and of the table.

          Comments are stored as ``MS_Description`` extended properties in the
          ``dbo`` schema.

          Args:
              table_sql (str): Original SQL of the table.
              table_name (str): Table name.

          Returns:
              str: ``sp_addextendedproperty`` calls, each followed by ``GO``.
          """

          def _property(description, column=None):
              lines = [
                  "EXEC sp_addextendedproperty",
                  f"'MS_Description', N'{description}',",
                  "'SCHEMA', N'dbo',",
              ]
              if column is None:
                  lines.append(f"'TABLE', N'{table_name}'")
              else:
                  lines.append(f"'TABLE', N'{table_name}',")
                  lines.append(f"'COLUMN', N'{column}'")
              lines.append("GO")
              return "\n".join(lines) + "\n"

          statements = []
          for line in table_sql.split("\n"):
              match = re.match(r"^`([^`]+)`.* COMMENT '([^']+)'", line.strip())
              if match:
                  # A literal backslash-n in the comment becomes a line break.
                  comment = match.group(2).replace("\\n", "\n")
                  statements.append(_property(comment, match.group(1)))

          match = re.search(r"COMMENT \= '([^']+)';", table_sql)
          if match:
              statements.append(_property(match.group(1)))
          return "".join(statements)

      def gen_pk(self, table_name) -> str:
          """Return an empty script; the key is declared inline by gen_create."""
          return ""

      def gen_index(self, ddl) -> str:
          """Generate one CREATE INDEX batch per entry of ``ddl["index"]``."""

          def generate_columns(columns):
              keys = []
              for col in columns[0]:
                  order = col["order"]
                  # Only a non-default sort order is spelled out.
                  suffix = " " + order.lower() if order and order.upper() != "ASC" else ""
                  keys.append(col["name"].lower() + suffix)
              return ", ".join(keys)

          table_name = ddl["table_name"].lower()
          statements = []
          for no, index in enumerate(ddl["index"], 1):
              columns = generate_columns(index["columns"])
              statements.append(
                  f"CREATE INDEX idx_{table_name}_{no:02d} ON {table_name} ({columns})\nGO\n"
              )
          return "".join(statements)

      def gen_insert(self, table_name) -> str:
          """Generate the INSERT statements of the table.

          The rows are wrapped in a transaction with ``IDENTITY_INSERT``
          switched on so that the original ids are kept.

          Args:
              table_name (str): Table name as written in the source script.

          Returns:
              str: Generated script, empty when the table has no rows.

          Raises:
              ValueError: If an INSERT line of the table has no `` VALUES ``.
          """
          prefix = f"INSERT INTO `{table_name}`"
          target = table_name.lower()

          # Collect the insert statements that belong to `table_name`.
          inserts = []
          for line in self.content.split("\n"):
              if not line.startswith(prefix):
                  continue
              # Cut the prefix off the front only; a replace() would also
              # remove the same text from inside the row data.
              head, sep, tail = line[len(prefix):].partition(" VALUES ")
              if not sep:
                  raise ValueError(f"no ' VALUES ' in insert statement: {line[:80]}")
              head = head.strip().replace("`", "").lower()
              tail = tail.strip().replace(r"\"", '"')
              # SQL Server: prefix strings with N. HACK: this may also touch
              # text inside a string value that happens to contain ", '".
              tail = tail.replace(", '", ", N'")
              # The first value has no leading ", " and needs its own check.
              if tail.startswith("('"):
                  tail = "(N'" + tail[2:]
              script = f"INSERT INTO {target} {head} VALUES {tail}"
              # Convert bit(1) data.
              script = script.replace("b'0'", "'0'").replace("b'1'", "'1'")
              # Replace the trailing semicolon with a GO batch separator.
              script = re.sub(";$", r"\nGO", script)
              inserts.append(script)

          # Build the insert script.
          if not inserts:
              return ""
          lines = [
              "-- ----------------------------",
              f"-- Records of {target}",
              "-- ----------------------------",
              "-- @formatter:off",
              "BEGIN TRANSACTION",
              "GO",
              f"SET IDENTITY_INSERT {target} ON",
              "GO",
              *inserts,
              f"SET IDENTITY_INSERT {target} OFF",
              "GO",
              "COMMIT",
              "GO",
              "-- @formatter:on",
          ]
          return "\n\n" + "\n".join(lines)
  def main():
      """Convert the MySQL script and print the result.

      Without command-line arguments the behaviour is unchanged:
      ``../mysql/ruoyi-vue-pro.sql`` (relative to the working directory) is
      converted for SQL Server. The target database and the source file can
      now be chosen on the command line instead of by editing this function.
      """
      import argparse

      convertors = {
          "postgres": PostgreSQLConvertor,
          "oracle": OracleConvertor,
          "sqlserver": SQLServerConvertor,
      }
      parser = argparse.ArgumentParser(description="Yudao Database Transfer Tool")
      parser.add_argument(
          "db_type",
          nargs="?",
          default="sqlserver",
          choices=sorted(convertors),
          help="target database (default: sqlserver)",
      )
      parser.add_argument(
          "--sql-file",
          default="../mysql/ruoyi-vue-pro.sql",
          help="path of the MySQL script (default: ../mysql/ruoyi-vue-pro.sql)",
      )
      args = parser.parse_args()

      sql_file = pathlib.Path(args.sql_file).resolve().as_posix()
      convertor = convertors[args.db_type](sql_file)
      convertor.print()
  # Run the conversion only when the file is executed as a script.
  if __name__ == "__main__":
      main()