新增 crud 代码自动生成
This commit is contained in:
parent
65f92947f5
commit
7eb590a697
0
kinit-api/scripts/crud_generate/__init__.py
Normal file
0
kinit-api/scripts/crud_generate/__init__.py
Normal file
166
kinit-api/scripts/crud_generate/main.py
Normal file
166
kinit-api/scripts/crud_generate/main.py
Normal file
@ -0,0 +1,166 @@
|
||||
#!/usr/bin/python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @version : 1.0
|
||||
# @Create Time : 2022/12/9 15:27
|
||||
# @File : main.py
|
||||
# @IDE : PyCharm
|
||||
# @desc : 简要说明
|
||||
|
||||
import os.path
|
||||
import sys
|
||||
from typing import Type
|
||||
from application.settings import BASE_DIR
|
||||
import inspect
|
||||
from pathlib import Path
|
||||
from core.database import Base
|
||||
from scripts.crud_generate.utils.generate_base import GenerateBase
|
||||
from scripts.crud_generate.utils.schema_generate import SchemaGenerate
|
||||
from scripts.crud_generate.utils.params_generate import ParamsGenerate
|
||||
from scripts.crud_generate.utils.dal_generate import DalGenerate
|
||||
from scripts.crud_generate.utils.view_generate import ViewGenerate
|
||||
|
||||
|
||||
class CrudGenerate(GenerateBase):
|
||||
|
||||
APPS_ROOT = os.path.join(BASE_DIR, "apps")
|
||||
SCRIPT_DIR = os.path.join(BASE_DIR, 'scripts', 'crud_generate')
|
||||
|
||||
def __init__(self, model: Type[Base], zh_name: str, en_name: str = None):
|
||||
"""
|
||||
初始化工作
|
||||
:param model: 提前定义好的 ORM 模型
|
||||
:param zh_name: 功能中文名称,主要用于描述、注释
|
||||
:param en_name: 功能英文名称,主要用于 schema、param 文件命名,以及它们的 class 命名,dal、url 命名,默认使用 model class
|
||||
en_name 例子:
|
||||
如果 en_name 由多个单词组成那么请使用 _ 下划线拼接
|
||||
在命名文件名称时,会执行使用 _ 下划线名称
|
||||
在命名 class 名称时,会将下划线名称转换为大驼峰命名(CamelCase)
|
||||
在命名 url 时,会将下划线转换为 /
|
||||
"""
|
||||
self.model = model
|
||||
self.zh_name = zh_name
|
||||
# model 文件的地址
|
||||
self.model_file_path = Path(inspect.getfile(sys.modules[model.__module__]))
|
||||
# model 文件 app 路径
|
||||
self.app_dir_path = self.model_file_path.parent.parent
|
||||
# schemas 目录地址
|
||||
self.schemas_dir_path = self.app_dir_path / "schemas"
|
||||
# params 目录地址
|
||||
self.params_dir_path = self.app_dir_path / "params"
|
||||
# crud 文件地址
|
||||
self.crud_file_path = self.app_dir_path / "crud.py"
|
||||
# view 文件地址
|
||||
self.view_file_path = self.app_dir_path / "views.py"
|
||||
|
||||
if en_name:
|
||||
self.en_name = en_name
|
||||
else:
|
||||
self.en_name = self.model.__name__
|
||||
|
||||
self.schema_file_path = self.schemas_dir_path / f"{self.en_name}.py"
|
||||
self.param_file_path = self.params_dir_path / f"{self.en_name}.py"
|
||||
|
||||
self.base_class_name = self.snake_to_camel(self.en_name)
|
||||
self.schema_simple_out_class_name = f"{self.base_class_name}SimpleOut"
|
||||
self.dal_class_name = f"{self.base_class_name}Dal"
|
||||
self.param_class_name = f"{self.base_class_name}Params"
|
||||
|
||||
def generate_codes(self):
|
||||
"""
|
||||
生成代码, 不做实际操作,只是将代码打印出来
|
||||
:return:
|
||||
"""
|
||||
print(f"==========================={self.schema_file_path} 代码内容=================================")
|
||||
schema = SchemaGenerate(
|
||||
self.model,
|
||||
self.zh_name,
|
||||
self.en_name,
|
||||
self.schema_file_path,
|
||||
self.schemas_dir_path,
|
||||
self.base_class_name,
|
||||
self.schema_simple_out_class_name
|
||||
)
|
||||
print(schema.generate_code())
|
||||
|
||||
print(f"==========================={self.dal_class_name} 代码内容=================================")
|
||||
dal = DalGenerate(
|
||||
self.model,
|
||||
self.zh_name,
|
||||
self.en_name,
|
||||
self.dal_class_name,
|
||||
self.schema_simple_out_class_name
|
||||
)
|
||||
print(dal.generate_code())
|
||||
|
||||
print(f"==========================={self.param_file_path} 代码内容=================================")
|
||||
params = ParamsGenerate(
|
||||
self.model,
|
||||
self.zh_name,
|
||||
self.en_name,
|
||||
self.params_dir_path,
|
||||
self.param_file_path,
|
||||
self.param_class_name
|
||||
)
|
||||
print(params.generate_code())
|
||||
|
||||
print(f"==========================={self.view_file_path} 代码内容=================================")
|
||||
view = ViewGenerate(
|
||||
self.model,
|
||||
self.zh_name,
|
||||
self.en_name,
|
||||
self.base_class_name,
|
||||
self.schema_simple_out_class_name,
|
||||
self.dal_class_name,
|
||||
self.param_class_name
|
||||
)
|
||||
print(view.generate_code())
|
||||
|
||||
def main(self):
|
||||
"""
|
||||
开始生成 crud 代码,并直接写入到项目中,目前还未实现
|
||||
1. 生成 schemas 代码
|
||||
2. 生成 dal 代码
|
||||
3. 生成 params 代码
|
||||
4. 生成 views 代码
|
||||
:return:
|
||||
"""
|
||||
schema = SchemaGenerate(
|
||||
self.model,
|
||||
self.zh_name,
|
||||
self.en_name,
|
||||
self.schema_file_path,
|
||||
self.schemas_dir_path,
|
||||
self.base_class_name,
|
||||
self.schema_simple_out_class_name
|
||||
)
|
||||
schema.write_generate_code()
|
||||
|
||||
dal = DalGenerate(
|
||||
self.model,
|
||||
self.zh_name,
|
||||
self.en_name,
|
||||
self.dal_class_name,
|
||||
self.schema_simple_out_class_name
|
||||
)
|
||||
dal.write_generate_code()
|
||||
|
||||
params = ParamsGenerate(
|
||||
self.model,
|
||||
self.zh_name,
|
||||
self.en_name,
|
||||
self.params_dir_path,
|
||||
self.param_file_path,
|
||||
self.param_class_name
|
||||
)
|
||||
params.write_generate_code()
|
||||
|
||||
view = ViewGenerate(
|
||||
self.model,
|
||||
self.zh_name,
|
||||
self.en_name,
|
||||
self.base_class_name,
|
||||
self.schema_simple_out_class_name,
|
||||
self.dal_class_name,
|
||||
self.param_class_name
|
||||
)
|
||||
view.write_generate_code()
|
0
kinit-api/scripts/crud_generate/utils/__init__.py
Normal file
0
kinit-api/scripts/crud_generate/utils/__init__.py
Normal file
106
kinit-api/scripts/crud_generate/utils/dal_generate.py
Normal file
106
kinit-api/scripts/crud_generate/utils/dal_generate.py
Normal file
@ -0,0 +1,106 @@
|
||||
import inspect
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Type
|
||||
from core.database import Base
|
||||
from .generate_base import GenerateBase
|
||||
|
||||
|
||||
class DalGenerate(GenerateBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Type[Base],
|
||||
zh_name: str,
|
||||
en_name: str,
|
||||
dal_class_name: str,
|
||||
schema_simple_out_class_name: str
|
||||
):
|
||||
"""
|
||||
初始化工作
|
||||
:param model: 提前定义好的 ORM 模型
|
||||
:param zh_name: 功能中文名称,主要用于描述、注释
|
||||
:param en_name: 功能英文名称,主要用于 schema、param 文件命名,以及它们的 class 命名,dal、url 命名,默认使用 model class
|
||||
en_name 例子:
|
||||
如果 en_name 由多个单词组成那么请使用 _ 下划线拼接
|
||||
在命名文件名称时,会执行使用 _ 下划线名称
|
||||
在命名 class 名称时,会将下划线名称转换为大驼峰命名(CamelCase)
|
||||
在命名 url 时,会将下划线转换为 /
|
||||
:param dal_class_name:
|
||||
:param schema_simple_out_class_name:
|
||||
"""
|
||||
self.model = model
|
||||
self.dal_class_name = dal_class_name
|
||||
self.schema_simple_out_class_name = schema_simple_out_class_name
|
||||
self.zh_name = zh_name
|
||||
self.en_name = en_name
|
||||
# model 文件的地址
|
||||
self.model_file_path = Path(inspect.getfile(sys.modules[model.__module__]))
|
||||
# model 文件 app 路径
|
||||
self.app_dir_path = self.model_file_path.parent.parent
|
||||
# crud 文件地址
|
||||
self.crud_file_path = self.app_dir_path / "crud.py"
|
||||
|
||||
def write_generate_code(self):
|
||||
"""
|
||||
生成 crud 文件,以及代码内容
|
||||
:return:
|
||||
"""
|
||||
if self.crud_file_path.exists():
|
||||
codes = self.file_code_split_module(self.crud_file_path)
|
||||
if codes:
|
||||
print(f"==========dal 文件已存在并已有代码内容,正在追加新代码============")
|
||||
if not codes[0]:
|
||||
# 无文件注释则添加文件注释
|
||||
codes[0] = self.generate_file_desc(self.crud_file_path.name, "1.0", "数据访问层")
|
||||
codes[1] = self.merge_dictionaries(codes[1], self.get_base_module_config())
|
||||
codes[2] += self.get_base_code_content()
|
||||
code = ''
|
||||
code += codes[0]
|
||||
code += self.generate_modules_code(codes[1])
|
||||
code += codes[2]
|
||||
self.crud_file_path.write_text(code, "utf-8")
|
||||
print(f"=================dal 代码已创建完成=======================")
|
||||
return
|
||||
self.crud_file_path.touch()
|
||||
code = self.generate_code()
|
||||
self.crud_file_path.write_text(code, "utf-8")
|
||||
print(f"===========================dal 代码创建完成=================================")
|
||||
|
||||
def generate_code(self):
|
||||
"""
|
||||
代码生成
|
||||
:return:
|
||||
"""
|
||||
code = self.generate_file_desc(self.crud_file_path.name, "1.0", "数据访问层")
|
||||
code += self.generate_modules_code(self.get_base_module_config())
|
||||
code += self.get_base_code_content()
|
||||
return code
|
||||
|
||||
@staticmethod
|
||||
def get_base_module_config():
|
||||
"""
|
||||
获取基础模块导入配置
|
||||
:return:
|
||||
"""
|
||||
modules = {
|
||||
"sqlalchemy.ext.asyncio": ['AsyncSession'],
|
||||
"core.crud": ["DalBase"],
|
||||
".": ["models", "schemas"],
|
||||
}
|
||||
return modules
|
||||
|
||||
def get_base_code_content(self):
|
||||
"""
|
||||
获取基础代码内容
|
||||
:return:
|
||||
"""
|
||||
base_code = f"\n\nclass {self.dal_class_name}(DalBase):\n"
|
||||
base_code += "\n\tdef __init__(self, db: AsyncSession):"
|
||||
base_code += f"\n\t\tsuper({self.dal_class_name}, self).__init__()"
|
||||
base_code += f"\n\t\tself.db = db"
|
||||
base_code += f"\n\t\tself.model = models.{self.model.__name__}"
|
||||
base_code += f"\n\t\tself.schema = schemas.{self.schema_simple_out_class_name}"
|
||||
base_code += "\n"
|
||||
return base_code.replace("\t", " ")
|
||||
|
185
kinit-api/scripts/crud_generate/utils/generate_base.py
Normal file
185
kinit-api/scripts/crud_generate/utils/generate_base.py
Normal file
@ -0,0 +1,185 @@
|
||||
import datetime
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class GenerateBase:
|
||||
|
||||
@staticmethod
|
||||
def camel_to_snake(name: str) -> str:
|
||||
"""
|
||||
将大驼峰命名(CamelCase)转换为下划线命名(snake_case)
|
||||
在大写字母前添加一个空格,然后将字符串分割并用下划线拼接
|
||||
:param name: 大驼峰命名(CamelCase)
|
||||
:return:
|
||||
"""
|
||||
s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
|
||||
return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
|
||||
|
||||
@staticmethod
|
||||
def snake_to_camel(name: str) -> str:
|
||||
"""
|
||||
将下划线命名(snake_case)转换为大驼峰命名(CamelCase)
|
||||
根据下划线分割,然后将字符串转为第一个字符大写后拼接
|
||||
:param name: 下划线命名(snake_case)
|
||||
:return:
|
||||
"""
|
||||
# 按下划线分割字符串
|
||||
words = name.split('_')
|
||||
# 将每个单词的首字母大写,然后拼接
|
||||
return ''.join(word.capitalize() for word in words)
|
||||
|
||||
@staticmethod
|
||||
def generate_file_desc(filename: str, version: str = '1.0', desc: str = '') -> str:
|
||||
"""
|
||||
生成文件注释
|
||||
:param filename:
|
||||
:param version:
|
||||
:param desc:
|
||||
:return:
|
||||
"""
|
||||
code = '#!/usr/bin/python\n# -*- coding: utf-8 -*-'
|
||||
code += f"\n# @version : {version}"
|
||||
code += f"\n# @Create Time : {datetime.datetime.now().strftime('%Y/%m/%d %H:%M')}"
|
||||
code += f"\n# @File : {filename}"
|
||||
code += f"\n# @IDE : PyCharm"
|
||||
code += f"\n# @desc : {desc}"
|
||||
code += f"\n"
|
||||
return code
|
||||
|
||||
@staticmethod
|
||||
def generate_modules_code(modules: dict[str, list]) -> str:
|
||||
"""
|
||||
生成模块导入代码
|
||||
:param modules: 导入得模块
|
||||
:return:
|
||||
"""
|
||||
code = "\n"
|
||||
args = modules.pop("args", [])
|
||||
for k, v in modules.items():
|
||||
code += f"from {k} import {', '.join(v)}\n"
|
||||
if args:
|
||||
code += f"import {', '.join(args)}\n"
|
||||
return code
|
||||
|
||||
@staticmethod
|
||||
def update_init_file(init_file: Path, code: str):
|
||||
"""
|
||||
__init__ 文件添加导入内容
|
||||
:param init_file:
|
||||
:param code:
|
||||
:return:
|
||||
"""
|
||||
content = init_file.read_text()
|
||||
if content and code in content:
|
||||
return
|
||||
if content:
|
||||
if content.endswith("\n"):
|
||||
with init_file.open("a+", encoding="utf-8") as f:
|
||||
f.write(f"{code}\n")
|
||||
else:
|
||||
with init_file.open("a+", encoding="utf-8") as f:
|
||||
f.write(f"\n{code}\n")
|
||||
else:
|
||||
init_file.write_text(f"{code}\n", encoding="utf-8")
|
||||
|
||||
@staticmethod
|
||||
def module_code_to_dict(code: str) -> dict:
|
||||
"""
|
||||
将 from import 语句代码转为 dict 格式
|
||||
:param code:
|
||||
:return:
|
||||
"""
|
||||
# 分解代码为单行
|
||||
lines = code.strip().split('\n')
|
||||
|
||||
# 初始化字典
|
||||
modules = {}
|
||||
|
||||
# 遍历每行代码
|
||||
for line in lines:
|
||||
# 处理 'from ... import ...' 类型的导入
|
||||
if line.startswith('from'):
|
||||
parts = line.split(' import ')
|
||||
module = parts[0][5:] # 移除 'from ' 并获取模块路径
|
||||
imports = parts[1].split(',') # 使用逗号分割导入项
|
||||
imports = [item.strip() for item in imports] # 移除多余空格
|
||||
if module in modules:
|
||||
modules[module].extend(imports)
|
||||
else:
|
||||
modules[module] = imports
|
||||
|
||||
# 处理 'import ...' 类型的导入
|
||||
elif line.startswith('import'):
|
||||
imports = line.split('import ')[1]
|
||||
# 分割多个导入项
|
||||
imports = imports.split(', ')
|
||||
for imp in imports:
|
||||
# 处理直接导入的模块
|
||||
modules.setdefault('args', []).append(imp)
|
||||
return modules
|
||||
|
||||
@classmethod
|
||||
def file_code_split_module(cls, file: Path) -> list:
|
||||
"""
|
||||
文件代码内容拆分,分为以下三部分
|
||||
1. 文件开头的注释。
|
||||
2. 全局层面的from import语句。该代码格式会被转换为 dict 格式
|
||||
3. 其他代码内容。
|
||||
:param file:
|
||||
:return:
|
||||
"""
|
||||
content = file.read_text(encoding="utf-8")
|
||||
if not content:
|
||||
return []
|
||||
lines = content.split('\n')
|
||||
part1 = [] # 文件开头注释
|
||||
part2 = [] # from import 语句
|
||||
part3 = [] # 其他代码内容
|
||||
|
||||
# 标记是否已超过注释部分
|
||||
past_comments = False
|
||||
|
||||
for line in lines:
|
||||
# 检查是否为注释行
|
||||
if line.startswith("#") and not past_comments:
|
||||
part1.append(line)
|
||||
else:
|
||||
# 标记已超过注释部分
|
||||
past_comments = True
|
||||
# 检查是否为 from import 语句
|
||||
if line.startswith("from ") or line.startswith("import "):
|
||||
part2.append(line)
|
||||
else:
|
||||
part3.append(line)
|
||||
|
||||
part2 = cls.module_code_to_dict('\n'.join(part2))
|
||||
|
||||
return ['\n'.join(part1), part2, '\n'.join(part3)]
|
||||
|
||||
@staticmethod
|
||||
def merge_dictionaries(dict1, dict2):
|
||||
"""
|
||||
合并两个键为字符串、值为列表的字典
|
||||
:param dict1:
|
||||
:param dict2:
|
||||
:return:
|
||||
"""
|
||||
# 初始化结果字典
|
||||
merged_dict = {}
|
||||
|
||||
# 合并两个字典中的键值对
|
||||
for key in set(dict1) | set(dict2): # 获取两个字典的键的并集
|
||||
merged_dict[key] = list(set(dict1.get(key, []) + dict2.get(key, [])))
|
||||
|
||||
return merged_dict
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
_modules = {
|
||||
"sqlalchemy.ext.asyncio": ['AsyncSession'],
|
||||
"core.crud": ["DalBase"],
|
||||
".": ["models", "schemas"],
|
||||
"args": ["test", "test1"]
|
||||
}
|
||||
print(GenerateBase.generate_modules_code(_modules))
|
82
kinit-api/scripts/crud_generate/utils/params_generate.py
Normal file
82
kinit-api/scripts/crud_generate/utils/params_generate.py
Normal file
@ -0,0 +1,82 @@
|
||||
import inspect
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Type
|
||||
from core.database import Base
|
||||
from .generate_base import GenerateBase
|
||||
|
||||
|
||||
class ParamsGenerate(GenerateBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Type[Base],
|
||||
zh_name: str,
|
||||
en_name: str,
|
||||
params_dir_path: Path,
|
||||
param_file_path: Path,
|
||||
param_class_name: str
|
||||
):
|
||||
"""
|
||||
初始化工作
|
||||
:param model: 提前定义好的 ORM 模型
|
||||
:param zh_name: 功能中文名称,主要用于描述、注释
|
||||
:param param_class_name:
|
||||
:param param_file_path:
|
||||
:param params_dir_path:
|
||||
:param en_name: 功能英文名称,主要用于 param、param 文件命名,以及它们的 class 命名,dal、url 命名,默认使用 model class
|
||||
en_name 例子:
|
||||
如果 en_name 由多个单词组成那么请使用 _ 下划线拼接
|
||||
在命名文件名称时,会执行使用 _ 下划线名称
|
||||
在命名 class 名称时,会将下划线名称转换为大驼峰命名(CamelCase)
|
||||
在命名 url 时,会将下划线转换为 /
|
||||
"""
|
||||
self.model = model
|
||||
self.param_class_name = param_class_name
|
||||
self.zh_name = zh_name
|
||||
self.en_name = en_name
|
||||
# model 文件的地址
|
||||
self.model_file_path = Path(inspect.getfile(sys.modules[model.__module__]))
|
||||
# model 文件 app 路径
|
||||
self.app_dir_path = self.model_file_path.parent.parent
|
||||
# params 目录地址
|
||||
self.params_dir_path = params_dir_path
|
||||
self.param_file_path = param_file_path
|
||||
|
||||
def write_generate_code(self):
|
||||
"""
|
||||
生成 params 文件,以及代码内容
|
||||
:return:
|
||||
"""
|
||||
param_init_file_path = self.params_dir_path / "__init__.py"
|
||||
self.param_file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if self.param_file_path.exists():
|
||||
self.param_file_path.unlink()
|
||||
self.param_file_path.touch()
|
||||
param_init_file_path.touch()
|
||||
|
||||
code = self.generate_code()
|
||||
self.param_file_path.write_text(code, "utf-8")
|
||||
init_code = f"from .{self.en_name} import {self.param_class_name}"
|
||||
self.update_init_file(param_init_file_path, init_code)
|
||||
print(f"===========================param 代码创建完成=================================")
|
||||
|
||||
def generate_code(self) -> str:
|
||||
"""
|
||||
生成 schema 代码内容
|
||||
:return:
|
||||
"""
|
||||
code = self.generate_file_desc(self.param_file_path.name, "1.0", self.zh_name)
|
||||
|
||||
modules = {
|
||||
"fastapi": ['Depends'],
|
||||
"core.dependencies": ['Paging', "QueryParams"],
|
||||
}
|
||||
code += self.generate_modules_code(modules)
|
||||
|
||||
base_code = f"\n\nclass {self.param_class_name}(QueryParams):"
|
||||
base_code += f"\n\tdef __init__(self, params: Paging = Depends()):"
|
||||
base_code += f"\n\t\tsuper().__init__(params)"
|
||||
base_code += "\n"
|
||||
code += base_code
|
||||
return code.replace("\t", " ")
|
11
kinit-api/scripts/crud_generate/utils/schema.py
Normal file
11
kinit-api/scripts/crud_generate/utils/schema.py
Normal file
@ -0,0 +1,11 @@
|
||||
from typing import Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SchemaField(BaseModel):
|
||||
name: str = Field(..., title="字段名称")
|
||||
field_type: str = Field(..., title="字段类型")
|
||||
nullable: bool = Field(False, title="是否可以为空")
|
||||
default: Any = Field(None, title="默认值")
|
||||
title: str | None = Field(None, title="字段描述")
|
||||
max_length: int | None = Field(None, title="最大长度")
|
143
kinit-api/scripts/crud_generate/utils/schema_generate.py
Normal file
143
kinit-api/scripts/crud_generate/utils/schema_generate.py
Normal file
@ -0,0 +1,143 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# @version : 1.0
|
||||
# @Create Time : 2024/1/12 17:28
|
||||
# @File : schema_generate.py
|
||||
# @IDE : PyCharm
|
||||
# @desc : schema 代码生成
|
||||
|
||||
|
||||
import sys
|
||||
from typing import Type
|
||||
import inspect
|
||||
from sqlalchemy import inspect as model_inspect
|
||||
from pathlib import Path
|
||||
from core.database import Base
|
||||
from scripts.crud_generate.utils.schema import SchemaField
|
||||
from sqlalchemy.sql.schema import Column as ColumnType
|
||||
from scripts.crud_generate.utils.generate_base import GenerateBase
|
||||
|
||||
|
||||
class SchemaGenerate(GenerateBase):
|
||||
|
||||
BASE_FIELDS = ["id", "create_datetime", "update_datetime", "delete_datetime", "is_delete"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Type[Base],
|
||||
zh_name: str,
|
||||
en_name: str,
|
||||
schema_file_path: Path,
|
||||
schemas_dir_path: Path,
|
||||
base_class_name: str,
|
||||
schema_simple_out_class_name: str
|
||||
):
|
||||
"""
|
||||
初始化工作
|
||||
:param model: 提前定义好的 ORM 模型
|
||||
:param zh_name: 功能中文名称,主要用于描述、注释
|
||||
:param schema_file_path:
|
||||
:param en_name: 功能英文名称,主要用于 schema、param 文件命名,以及它们的 class 命名,dal、url 命名,默认使用 model class
|
||||
en_name 例子:
|
||||
如果 en_name 由多个单词组成那么请使用 _ 下划线拼接
|
||||
在命名文件名称时,会执行使用 _ 下划线名称
|
||||
在命名 class 名称时,会将下划线名称转换为大驼峰命名(CamelCase)
|
||||
在命名 url 时,会将下划线转换为 /
|
||||
:param base_class_name:
|
||||
:param schema_simple_out_class_name:
|
||||
"""
|
||||
self.model = model
|
||||
self.base_class_name = base_class_name
|
||||
self.schema_simple_out_class_name = schema_simple_out_class_name
|
||||
self.zh_name = zh_name
|
||||
# model 文件的地址
|
||||
self.model_file_path = Path(inspect.getfile(sys.modules[model.__module__]))
|
||||
# model 文件 app 路径
|
||||
self.app_dir_path = self.model_file_path.parent.parent
|
||||
self.en_name = en_name
|
||||
self.schema_file_path = schema_file_path
|
||||
self.schemas_dir_path = schemas_dir_path
|
||||
self.schema_init_file_path = self.schemas_dir_path / "__init__.py"
|
||||
|
||||
def write_generate_code(self):
|
||||
"""
|
||||
生成 schema 文件,以及代码内容
|
||||
:return:
|
||||
"""
|
||||
self.schema_file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if self.schema_file_path.exists():
|
||||
# 存在则直接删除,重新创建写入
|
||||
self.schema_file_path.unlink()
|
||||
self.schema_file_path.touch()
|
||||
self.schema_init_file_path.touch()
|
||||
|
||||
code = self.generate_code()
|
||||
self.schema_file_path.write_text(code, "utf-8")
|
||||
|
||||
init_code = self.generate_init_code()
|
||||
self.update_init_file(self.schema_init_file_path, init_code)
|
||||
print(f"===========================schema 代码创建完成=================================")
|
||||
|
||||
def generate_init_code(self):
|
||||
"""
|
||||
生成 __init__ 文件导入代码
|
||||
todo 如果导入的类已经存在,则应该返回空
|
||||
:return:
|
||||
"""
|
||||
init_code = f"from .{self.en_name} import {self.base_class_name}, {self.schema_simple_out_class_name}"
|
||||
return init_code
|
||||
|
||||
def generate_code(self) -> str:
|
||||
"""
|
||||
生成 schema 代码内容
|
||||
:return:
|
||||
"""
|
||||
fields = []
|
||||
mapper = model_inspect(self.model)
|
||||
for attr_name, column_property in mapper.column_attrs.items():
|
||||
if attr_name in self.BASE_FIELDS:
|
||||
continue
|
||||
# 假设它是单列属性
|
||||
column: ColumnType = column_property.columns[0]
|
||||
item = SchemaField(
|
||||
name=attr_name,
|
||||
field_type=column.type.python_type.__name__,
|
||||
nullable=column.nullable,
|
||||
default=column.default.__dict__.get("arg", None) if column.default else None,
|
||||
title=column.comment,
|
||||
max_length=column.type.__dict__.get("length", None)
|
||||
)
|
||||
fields.append(item)
|
||||
|
||||
code = self.generate_file_desc(self.schema_file_path.name, "1.0", "pydantic 模型,用于数据库序列化操作")
|
||||
|
||||
modules = {
|
||||
"pydantic": ['BaseModel', "Field", "ConfigDict"],
|
||||
"core.data_types": ['DatetimeStr'],
|
||||
}
|
||||
code += self.generate_modules_code(modules)
|
||||
|
||||
base_schema_code = f"\n\nclass {self.base_class_name}(BaseModel):"
|
||||
for item in fields:
|
||||
field = f"\n\t{item.name}: {item.field_type} {'| None ' if item.nullable else ''}"
|
||||
default = None
|
||||
if item.default is not None:
|
||||
if item.field_type == "str":
|
||||
default = f"\"{item.default}\""
|
||||
else:
|
||||
default = item.default
|
||||
elif default is None and not item.nullable:
|
||||
default = "..."
|
||||
|
||||
field += f"= Field({default}, title=\"{item.title}\")"
|
||||
base_schema_code += field
|
||||
base_schema_code += "\n"
|
||||
code += base_schema_code
|
||||
|
||||
base_out_schema_code = f"\n\nclass {self.schema_simple_out_class_name}({self.base_class_name}):"
|
||||
base_out_schema_code += "\n\tmodel_config = ConfigDict(from_attributes=True)\n"
|
||||
base_out_schema_code += "\n\tid: int = Field(..., title=\"编号\")"
|
||||
base_out_schema_code += "\n\tcreate_datetime: DatetimeStr = Field(..., title=\"创建时间\")"
|
||||
base_out_schema_code += "\n\tupdate_datetime: DatetimeStr = Field(..., title=\"更新时间\")"
|
||||
base_out_schema_code += "\n"
|
||||
code += base_out_schema_code
|
||||
return code.replace("\t", " ")
|
143
kinit-api/scripts/crud_generate/utils/view_generate.py
Normal file
143
kinit-api/scripts/crud_generate/utils/view_generate.py
Normal file
@ -0,0 +1,143 @@
|
||||
import inspect
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Type
|
||||
from core.database import Base
|
||||
from .generate_base import GenerateBase
|
||||
|
||||
|
||||
class ViewGenerate(GenerateBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Type[Base],
|
||||
zh_name: str,
|
||||
en_name: str,
|
||||
schema_class_name: str,
|
||||
schema_simple_out_class_name: str,
|
||||
dal_class_name: str,
|
||||
param_class_name: str
|
||||
):
|
||||
"""
|
||||
初始化工作
|
||||
:param model: 提前定义好的 ORM 模型
|
||||
:param zh_name: 功能中文名称,主要用于描述、注释
|
||||
:param schema_class_name:
|
||||
:param schema_simple_out_class_name:
|
||||
:param dal_class_name:
|
||||
:param param_class_name:
|
||||
:param en_name: 功能英文名称,主要用于 schema、param 文件命名,以及它们的 class 命名,dal、url 命名,默认使用 model class
|
||||
en_name 例子:
|
||||
如果 en_name 由多个单词组成那么请使用 _ 下划线拼接
|
||||
在命名文件名称时,会执行使用 _ 下划线名称
|
||||
在命名 class 名称时,会将下划线名称转换为大驼峰命名(CamelCase)
|
||||
在命名 url 时,会将下划线转换为 /
|
||||
"""
|
||||
self.model = model
|
||||
self.schema_class_name = schema_class_name
|
||||
self.schema_simple_out_class_name = schema_simple_out_class_name
|
||||
self.dal_class_name = dal_class_name
|
||||
self.param_class_name = param_class_name
|
||||
self.zh_name = zh_name
|
||||
self.en_name = en_name
|
||||
# model 文件的地址
|
||||
self.model_file_path = Path(inspect.getfile(sys.modules[model.__module__]))
|
||||
# model 文件 app 路径
|
||||
self.app_dir_path = self.model_file_path.parent.parent
|
||||
# view 文件地址
|
||||
self.view_file_path = self.app_dir_path / "views.py"
|
||||
|
||||
def write_generate_code(self):
|
||||
"""
|
||||
生成 view 文件,以及代码内容
|
||||
:return:
|
||||
"""
|
||||
if self.view_file_path.exists():
|
||||
codes = self.file_code_split_module(self.view_file_path)
|
||||
if codes:
|
||||
print(f"==========view 文件已存在并已有代码内容,正在追加新代码============")
|
||||
if not codes[0]:
|
||||
# 无文件注释则添加文件注释
|
||||
codes[0] = self.generate_file_desc(self.view_file_path.name, "1.0", "视图层")
|
||||
codes[1] = self.merge_dictionaries(codes[1], self.get_base_module_config())
|
||||
codes[2] += self.get_base_code_content()
|
||||
code = ''
|
||||
code += codes[0]
|
||||
code += self.generate_modules_code(codes[1])
|
||||
if "app = APIRouter()" not in codes[2]:
|
||||
code += "\n\napp = APIRouter()"
|
||||
code += codes[2]
|
||||
self.view_file_path.write_text(code, "utf-8")
|
||||
print(f"=================view 代码已创建完成=====================")
|
||||
return
|
||||
else:
|
||||
self.view_file_path.touch()
|
||||
code = self.generate_code()
|
||||
self.view_file_path.write_text(code, encoding="utf-8")
|
||||
print(f"===============view 代码创建完成==================")
|
||||
|
||||
def generate_code(self) -> str:
|
||||
"""
|
||||
生成代码
|
||||
:return:
|
||||
"""
|
||||
code = self.generate_file_desc(self.view_file_path.name, "1.0", "路由,视图文件")
|
||||
code += self.generate_modules_code(self.get_base_module_config())
|
||||
code += "\n\napp = APIRouter()"
|
||||
code += self.get_base_code_content()
|
||||
|
||||
return code.replace("\t", " ")
|
||||
|
||||
@staticmethod
|
||||
def get_base_module_config():
|
||||
"""
|
||||
获取基础模块导入配置
|
||||
:return:
|
||||
"""
|
||||
modules = {
|
||||
"sqlalchemy.ext.asyncio": ['AsyncSession'],
|
||||
"fastapi": ["APIRouter", "Depends"],
|
||||
".": ["models", "schemas", "crud", "params"],
|
||||
"core.dependencies": ["IdList"],
|
||||
"apps.vadmin.auth.utils.current": ["AllUserAuth"],
|
||||
"utils.response": ["SuccessResponse"],
|
||||
"apps.vadmin.auth.utils.validation.auth": ["Auth"],
|
||||
"core.database": ["db_getter"],
|
||||
}
|
||||
return modules
|
||||
|
||||
def get_base_code_content(self):
|
||||
"""
|
||||
获取基础代码内容
|
||||
:return:
|
||||
"""
|
||||
base_code = "\n\n\n###########################################################"
|
||||
base_code += f"\n# {self.zh_name}"
|
||||
base_code += "\n###########################################################"
|
||||
|
||||
router = self.en_name.replace("_", "/")
|
||||
|
||||
base_code += f"\n@app.get(\"/{router}\", summary=\"获取{self.zh_name}列表\", tags=[\"{self.zh_name}\"])"
|
||||
base_code += f"\nasync def get_{self.en_name}_list(p: params.{self.param_class_name} = Depends(), auth: Auth = Depends(AllUserAuth())):"
|
||||
base_code += f"\n\tdatas, count = await crud.{self.dal_class_name}(auth.db).get_datas(**p.dict(), v_return_count=True)"
|
||||
base_code += f"\n\treturn SuccessResponse(datas, count=count)\n"
|
||||
|
||||
base_code += f"\n\n@app.post(\"/{router}\", summary=\"创建{self.zh_name}\", tags=[\"{self.zh_name}\"])"
|
||||
base_code += f"\nasync def create_{self.en_name}(data: schemas.{self.schema_class_name}, auth: Auth = Depends(AllUserAuth())):"
|
||||
base_code += f"\n\treturn SuccessResponse(await crud.{self.dal_class_name}(auth.db).create_data(data=data))\n"
|
||||
|
||||
base_code += f"\n\n@app.delete(\"/{router}\", summary=\"删除{self.zh_name}\", description=\"硬删除\", tags=[\"{self.zh_name}\"])"
|
||||
base_code += f"\nasync def delete_{self.en_name}_list(ids: IdList = Depends(), auth: Auth = Depends(AllUserAuth())):"
|
||||
base_code += f"\n\tawait crud.{self.dal_class_name}(auth.db).delete_datas(ids=ids.ids, v_soft=False)"
|
||||
base_code += f"\n\treturn SuccessResponse(\"删除成功\")\n"
|
||||
|
||||
base_code += f"\n\n@app.put(\"/{router}" + "/{data_id}\"" + f", summary=\"更新{self.zh_name}\", tags=[\"{self.zh_name}\"])"
|
||||
base_code += f"\nasync def put_{self.en_name}(data_id: int, data: schemas.{self.schema_class_name}, auth: Auth = Depends(AllUserAuth())):"
|
||||
base_code += f"\n\treturn SuccessResponse(await crud.{self.dal_class_name}(auth.db).put_data(data_id, data))\n"
|
||||
|
||||
base_code += f"\n\n@app.get(\"/{router}" + "/{data_id}\"" + f", summary=\"获取{self.zh_name}信息\", tags=[\"{self.zh_name}\"])"
|
||||
base_code += f"\nasync def get_{self.en_name}(data_id: int, db: AsyncSession = Depends(db_getter)):"
|
||||
base_code += f"\n\tschema = schemas.{self.schema_simple_out_class_name}"
|
||||
base_code += f"\n\treturn SuccessResponse(await crud.{self.dal_class_name}(db).get_data(data_id, v_schema=schema))\n"
|
||||
base_code += "\n"
|
||||
return base_code.replace("\t", " ")
|
Loading…
x
Reference in New Issue
Block a user