diff --git a/kinit-api/scripts/crud_generate/__init__.py b/kinit-api/scripts/crud_generate/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/kinit-api/scripts/crud_generate/main.py b/kinit-api/scripts/crud_generate/main.py new file mode 100644 index 0000000..82e0043 --- /dev/null +++ b/kinit-api/scripts/crud_generate/main.py @@ -0,0 +1,166 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +# @version : 1.0 +# @Create Time : 2022/12/9 15:27 +# @File : main.py +# @IDE : PyCharm +# @desc : 简要说明 + +import os.path +import sys +from typing import Type +from application.settings import BASE_DIR +import inspect +from pathlib import Path +from core.database import Base +from scripts.crud_generate.utils.generate_base import GenerateBase +from scripts.crud_generate.utils.schema_generate import SchemaGenerate +from scripts.crud_generate.utils.params_generate import ParamsGenerate +from scripts.crud_generate.utils.dal_generate import DalGenerate +from scripts.crud_generate.utils.view_generate import ViewGenerate + + +class CrudGenerate(GenerateBase): + + APPS_ROOT = os.path.join(BASE_DIR, "apps") + SCRIPT_DIR = os.path.join(BASE_DIR, 'scripts', 'crud_generate') + + def __init__(self, model: Type[Base], zh_name: str, en_name: str = None): + """ + 初始化工作 + :param model: 提前定义好的 ORM 模型 + :param zh_name: 功能中文名称,主要用于描述、注释 + :param en_name: 功能英文名称,主要用于 schema、param 文件命名,以及它们的 class 命名,dal、url 命名,默认使用 model class + en_name 例子: + 如果 en_name 由多个单词组成那么请使用 _ 下划线拼接 + 在命名文件名称时,会执行使用 _ 下划线名称 + 在命名 class 名称时,会将下划线名称转换为大驼峰命名(CamelCase) + 在命名 url 时,会将下划线转换为 / + """ + self.model = model + self.zh_name = zh_name + # model 文件的地址 + self.model_file_path = Path(inspect.getfile(sys.modules[model.__module__])) + # model 文件 app 路径 + self.app_dir_path = self.model_file_path.parent.parent + # schemas 目录地址 + self.schemas_dir_path = self.app_dir_path / "schemas" + # params 目录地址 + self.params_dir_path = self.app_dir_path / "params" + # crud 文件地址 + self.crud_file_path = self.app_dir_path / "crud.py" + # view 文件地址 + self.view_file_path = self.app_dir_path / "views.py" + + if en_name: + self.en_name = en_name + else: + self.en_name = self.model.__name__ + + self.schema_file_path = self.schemas_dir_path / f"{self.en_name}.py" + self.param_file_path = self.params_dir_path / f"{self.en_name}.py" + + self.base_class_name = self.snake_to_camel(self.en_name) + self.schema_simple_out_class_name = f"{self.base_class_name}SimpleOut" + self.dal_class_name = f"{self.base_class_name}Dal" + self.param_class_name = f"{self.base_class_name}Params" + + def generate_codes(self): + """ + 生成代码, 不做实际操作,只是将代码打印出来 + :return: + """ + print(f"==========================={self.schema_file_path} 代码内容=================================") + schema = SchemaGenerate( + self.model, + self.zh_name, + self.en_name, + self.schema_file_path, + self.schemas_dir_path, + self.base_class_name, + self.schema_simple_out_class_name + ) + print(schema.generate_code()) + + print(f"==========================={self.dal_class_name} 代码内容=================================") + dal = DalGenerate( + self.model, + self.zh_name, + self.en_name, + self.dal_class_name, + self.schema_simple_out_class_name + ) + print(dal.generate_code()) + + print(f"==========================={self.param_file_path} 代码内容=================================") + params = ParamsGenerate( + self.model, + self.zh_name, + self.en_name, + self.params_dir_path, + self.param_file_path, + self.param_class_name + ) + print(params.generate_code()) + + print(f"==========================={self.view_file_path} 代码内容=================================") + view = ViewGenerate( + self.model, + self.zh_name, + self.en_name, + self.base_class_name, + self.schema_simple_out_class_name, + self.dal_class_name, + self.param_class_name + ) + print(view.generate_code()) + + def main(self): + """ + 开始生成 crud 代码,并直接写入到项目中,目前还未实现 + 1. 生成 schemas 代码 + 2. 生成 dal 代码 + 3. 生成 params 代码 + 4. 生成 views 代码 + :return: + """ + schema = SchemaGenerate( + self.model, + self.zh_name, + self.en_name, + self.schema_file_path, + self.schemas_dir_path, + self.base_class_name, + self.schema_simple_out_class_name + ) + schema.write_generate_code() + + dal = DalGenerate( + self.model, + self.zh_name, + self.en_name, + self.dal_class_name, + self.schema_simple_out_class_name + ) + dal.write_generate_code() + + params = ParamsGenerate( + self.model, + self.zh_name, + self.en_name, + self.params_dir_path, + self.param_file_path, + self.param_class_name + ) + params.write_generate_code() + + view = ViewGenerate( + self.model, + self.zh_name, + self.en_name, + self.base_class_name, + self.schema_simple_out_class_name, + self.dal_class_name, + self.param_class_name + ) + view.write_generate_code() diff --git a/kinit-api/scripts/crud_generate/utils/__init__.py b/kinit-api/scripts/crud_generate/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/kinit-api/scripts/crud_generate/utils/dal_generate.py b/kinit-api/scripts/crud_generate/utils/dal_generate.py new file mode 100644 index 0000000..d2ce5ae --- /dev/null +++ b/kinit-api/scripts/crud_generate/utils/dal_generate.py @@ -0,0 +1,106 @@ +import inspect +import sys +from pathlib import Path +from typing import Type +from core.database import Base +from .generate_base import GenerateBase + + +class DalGenerate(GenerateBase): + + def __init__( + self, + model: Type[Base], + zh_name: str, + en_name: str, + dal_class_name: str, + schema_simple_out_class_name: str + ): + """ + 初始化工作 + :param model: 提前定义好的 ORM 模型 + :param zh_name: 功能中文名称,主要用于描述、注释 + :param en_name: 功能英文名称,主要用于 schema、param 文件命名,以及它们的 class 命名,dal、url 命名,默认使用 model class + en_name 例子: + 如果 en_name 由多个单词组成那么请使用 _ 下划线拼接 + 在命名文件名称时,会执行使用 _ 下划线名称 + 在命名 class 名称时,会将下划线名称转换为大驼峰命名(CamelCase) + 在命名 url 时,会将下划线转换为 / + :param dal_class_name: + :param schema_simple_out_class_name: + """ + self.model = model + self.dal_class_name = dal_class_name + self.schema_simple_out_class_name = schema_simple_out_class_name + self.zh_name = zh_name + self.en_name = en_name + # model 文件的地址 + self.model_file_path = Path(inspect.getfile(sys.modules[model.__module__])) + # model 文件 app 路径 + self.app_dir_path = self.model_file_path.parent.parent + # crud 文件地址 + self.crud_file_path = self.app_dir_path / "crud.py" + + def write_generate_code(self): + """ + 生成 crud 文件,以及代码内容 + :return: + """ + if self.crud_file_path.exists(): + codes = self.file_code_split_module(self.crud_file_path) + if codes: + print(f"==========dal 文件已存在并已有代码内容,正在追加新代码============") + if not codes[0]: + # 无文件注释则添加文件注释 + codes[0] = self.generate_file_desc(self.crud_file_path.name, "1.0", "数据访问层") + codes[1] = self.merge_dictionaries(codes[1], self.get_base_module_config()) + codes[2] += self.get_base_code_content() + code = '' + code += codes[0] + code += self.generate_modules_code(codes[1]) + code += codes[2] + self.crud_file_path.write_text(code, "utf-8") + print(f"=================dal 代码已创建完成=======================") + return + self.crud_file_path.touch() + code = self.generate_code() + self.crud_file_path.write_text(code, "utf-8") + print(f"===========================dal 代码创建完成=================================") + + def generate_code(self): + """ + 代码生成 + :return: + """ + code = self.generate_file_desc(self.crud_file_path.name, "1.0", "数据访问层") + code += self.generate_modules_code(self.get_base_module_config()) + code += self.get_base_code_content() + return code + + @staticmethod + def get_base_module_config(): + """ + 获取基础模块导入配置 + :return: + """ + modules = { + "sqlalchemy.ext.asyncio": ['AsyncSession'], + "core.crud": ["DalBase"], + ".": ["models", "schemas"], + } + return modules + + def get_base_code_content(self): + """ + 获取基础代码内容 + :return: + """ + base_code = f"\n\nclass {self.dal_class_name}(DalBase):\n" + base_code += "\n\tdef __init__(self, db: AsyncSession):" + base_code += f"\n\t\tsuper({self.dal_class_name}, self).__init__()" + base_code += f"\n\t\tself.db = db" + base_code += f"\n\t\tself.model = models.{self.model.__name__}" + base_code += f"\n\t\tself.schema = schemas.{self.schema_simple_out_class_name}" + base_code += "\n" + return base_code.replace("\t", " ") + diff --git a/kinit-api/scripts/crud_generate/utils/generate_base.py b/kinit-api/scripts/crud_generate/utils/generate_base.py new file mode 100644 index 0000000..c19ae5b --- /dev/null +++ b/kinit-api/scripts/crud_generate/utils/generate_base.py @@ -0,0 +1,185 @@ +import datetime +import re +from pathlib import Path + + +class GenerateBase: + + @staticmethod + def camel_to_snake(name: str) -> str: + """ + 将大驼峰命名(CamelCase)转换为下划线命名(snake_case) + 在大写字母前添加一个空格,然后将字符串分割并用下划线拼接 + :param name: 大驼峰命名(CamelCase) + :return: + """ + s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name) + return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower() + + @staticmethod + def snake_to_camel(name: str) -> str: + """ + 将下划线命名(snake_case)转换为大驼峰命名(CamelCase) + 根据下划线分割,然后将字符串转为第一个字符大写后拼接 + :param name: 下划线命名(snake_case) + :return: + """ + # 按下划线分割字符串 + words = name.split('_') + # 将每个单词的首字母大写,然后拼接 + return ''.join(word.capitalize() for word in words) + + @staticmethod + def generate_file_desc(filename: str, version: str = '1.0', desc: str = '') -> str: + """ + 生成文件注释 + :param filename: + :param version: + :param desc: + :return: + """ + code = '#!/usr/bin/python\n# -*- coding: utf-8 -*-' + code += f"\n# @version : {version}" + code += f"\n# @Create Time : {datetime.datetime.now().strftime('%Y/%m/%d %H:%M')}" + code += f"\n# @File : {filename}" + code += f"\n# @IDE : PyCharm" + code += f"\n# @desc : {desc}" + code += f"\n" + return code + + @staticmethod + def generate_modules_code(modules: dict[str, list]) -> str: + """ + 生成模块导入代码 + :param modules: 导入得模块 + :return: + """ + code = "\n" + args = modules.pop("args", []) + for k, v in modules.items(): + code += f"from {k} import {', '.join(v)}\n" + if args: + code += f"import {', '.join(args)}\n" + return code + + @staticmethod + def update_init_file(init_file: Path, code: str): + """ + __init__ 文件添加导入内容 + :param init_file: + :param code: + :return: + """ + content = init_file.read_text() + if content and code in content: + return + if content: + if content.endswith("\n"): + with init_file.open("a+", encoding="utf-8") as f: + f.write(f"{code}\n") + else: + with init_file.open("a+", encoding="utf-8") as f: + f.write(f"\n{code}\n") + else: + init_file.write_text(f"{code}\n", encoding="utf-8") + + @staticmethod + def module_code_to_dict(code: str) -> dict: + """ + 将 from import 语句代码转为 dict 格式 + :param code: + :return: + """ + # 分解代码为单行 + lines = code.strip().split('\n') + + # 初始化字典 + modules = {} + + # 遍历每行代码 + for line in lines: + # 处理 'from ... import ...' 类型的导入 + if line.startswith('from'): + parts = line.split(' import ') + module = parts[0][5:] # 移除 'from ' 并获取模块路径 + imports = parts[1].split(',') # 使用逗号分割导入项 + imports = [item.strip() for item in imports] # 移除多余空格 + if module in modules: + modules[module].extend(imports) + else: + modules[module] = imports + + # 处理 'import ...' 类型的导入 + elif line.startswith('import'): + imports = line.split('import ')[1] + # 分割多个导入项 + imports = imports.split(', ') + for imp in imports: + # 处理直接导入的模块 + modules.setdefault('args', []).append(imp) + return modules + + @classmethod + def file_code_split_module(cls, file: Path) -> list: + """ + 文件代码内容拆分,分为以下三部分 + 1. 文件开头的注释。 + 2. 全局层面的from import语句。该代码格式会被转换为 dict 格式 + 3. 其他代码内容。 + :param file: + :return: + """ + content = file.read_text(encoding="utf-8") + if not content: + return [] + lines = content.split('\n') + part1 = [] # 文件开头注释 + part2 = [] # from import 语句 + part3 = [] # 其他代码内容 + + # 标记是否已超过注释部分 + past_comments = False + + for line in lines: + # 检查是否为注释行 + if line.startswith("#") and not past_comments: + part1.append(line) + else: + # 标记已超过注释部分 + past_comments = True + # 检查是否为 from import 语句 + if line.startswith("from ") or line.startswith("import "): + part2.append(line) + else: + part3.append(line) + + part2 = cls.module_code_to_dict('\n'.join(part2)) + + return ['\n'.join(part1), part2, '\n'.join(part3)] + + @staticmethod + def merge_dictionaries(dict1, dict2): + """ + 合并两个键为字符串、值为列表的字典 + :param dict1: + :param dict2: + :return: + """ + # 初始化结果字典 + merged_dict = {} + + # 合并两个字典中的键值对 + for key in set(dict1) | set(dict2): # 获取两个字典的键的并集 + merged_dict[key] = list(set(dict1.get(key, []) + dict2.get(key, []))) + + return merged_dict + + +if __name__ == '__main__': + _modules = { + "sqlalchemy.ext.asyncio": ['AsyncSession'], + "core.crud": ["DalBase"], + ".": ["models", "schemas"], + "args": ["test", "test1"] + } + print(GenerateBase.generate_modules_code(_modules)) diff --git a/kinit-api/scripts/crud_generate/utils/params_generate.py b/kinit-api/scripts/crud_generate/utils/params_generate.py new file mode 100644 index 0000000..d87510c --- /dev/null +++ b/kinit-api/scripts/crud_generate/utils/params_generate.py @@ -0,0 +1,82 @@ +import inspect +import sys +from pathlib import Path +from typing import Type +from core.database import Base +from .generate_base import GenerateBase + + +class ParamsGenerate(GenerateBase): + + def __init__( + self, + model: Type[Base], + zh_name: str, + en_name: str, + params_dir_path: Path, + param_file_path: Path, + param_class_name: str + ): + """ + 初始化工作 + :param model: 提前定义好的 ORM 模型 + :param zh_name: 功能中文名称,主要用于描述、注释 + :param param_class_name: + :param param_file_path: + :param params_dir_path: + :param en_name: 功能英文名称,主要用于 param、param 文件命名,以及它们的 class 命名,dal、url 命名,默认使用 model class + en_name 例子: + 如果 en_name 由多个单词组成那么请使用 _ 下划线拼接 + 在命名文件名称时,会执行使用 _ 下划线名称 + 在命名 class 名称时,会将下划线名称转换为大驼峰命名(CamelCase) + 在命名 url 时,会将下划线转换为 / + """ + self.model = model + self.param_class_name = param_class_name + self.zh_name = zh_name + self.en_name = en_name + # model 文件的地址 + self.model_file_path = Path(inspect.getfile(sys.modules[model.__module__])) + # model 文件 app 路径 + self.app_dir_path = self.model_file_path.parent.parent + # params 目录地址 + self.params_dir_path = params_dir_path + self.param_file_path = param_file_path + + def write_generate_code(self): + """ + 生成 params 文件,以及代码内容 + :return: + """ + param_init_file_path = self.params_dir_path / "__init__.py" + self.param_file_path.parent.mkdir(parents=True, exist_ok=True) + if self.param_file_path.exists(): + self.param_file_path.unlink() + self.param_file_path.touch() + param_init_file_path.touch() + + code = self.generate_code() + self.param_file_path.write_text(code, "utf-8") + init_code = f"from .{self.en_name} import {self.param_class_name}" + self.update_init_file(param_init_file_path, init_code) + print(f"===========================param 代码创建完成=================================") + + def generate_code(self) -> str: + """ + 生成 schema 代码内容 + :return: + """ + code = self.generate_file_desc(self.param_file_path.name, "1.0", self.zh_name) + + modules = { + "fastapi": ['Depends'], + "core.dependencies": ['Paging', "QueryParams"], + } + code += self.generate_modules_code(modules) + + base_code = f"\n\nclass {self.param_class_name}(QueryParams):" + base_code += f"\n\tdef __init__(self, params: Paging = Depends()):" + base_code += f"\n\t\tsuper().__init__(params)" + base_code += "\n" + code += base_code + return code.replace("\t", " ") diff --git a/kinit-api/scripts/crud_generate/utils/schema.py b/kinit-api/scripts/crud_generate/utils/schema.py new file mode 100644 index 0000000..65ccfa6 --- /dev/null +++ b/kinit-api/scripts/crud_generate/utils/schema.py @@ -0,0 +1,11 @@ +from typing import Any +from pydantic import BaseModel, Field + + +class SchemaField(BaseModel): + name: str = Field(..., title="字段名称") + field_type: str = Field(..., title="字段类型") + nullable: bool = Field(False, title="是否可以为空") + default: Any = Field(None, title="默认值") + title: str | None = Field(None, title="字段描述") + max_length: int | None = Field(None, title="最大长度") diff --git a/kinit-api/scripts/crud_generate/utils/schema_generate.py b/kinit-api/scripts/crud_generate/utils/schema_generate.py new file mode 100644 index 0000000..cdb2a15 --- /dev/null +++ b/kinit-api/scripts/crud_generate/utils/schema_generate.py @@ -0,0 +1,143 @@ +# -*- coding: utf-8 -*- +# @version : 1.0 +# @Create Time : 2024/1/12 17:28 +# @File : schema_generate.py +# @IDE : PyCharm +# @desc : schema 代码生成 + + +import sys +from typing import Type +import inspect +from sqlalchemy import inspect as model_inspect +from pathlib import Path +from core.database import Base +from scripts.crud_generate.utils.schema import SchemaField +from sqlalchemy.sql.schema import Column as ColumnType +from scripts.crud_generate.utils.generate_base import GenerateBase + + +class SchemaGenerate(GenerateBase): + + BASE_FIELDS = ["id", "create_datetime", "update_datetime", "delete_datetime", "is_delete"] + + def __init__( + self, + model: Type[Base], + zh_name: str, + en_name: str, + schema_file_path: Path, + schemas_dir_path: Path, + base_class_name: str, + schema_simple_out_class_name: str + ): + """ + 初始化工作 + :param model: 提前定义好的 ORM 模型 + :param zh_name: 功能中文名称,主要用于描述、注释 + :param schema_file_path: + :param en_name: 功能英文名称,主要用于 schema、param 文件命名,以及它们的 class 命名,dal、url 命名,默认使用 model class + en_name 例子: + 如果 en_name 由多个单词组成那么请使用 _ 下划线拼接 + 在命名文件名称时,会执行使用 _ 下划线名称 + 在命名 class 名称时,会将下划线名称转换为大驼峰命名(CamelCase) + 在命名 url 时,会将下划线转换为 / + :param base_class_name: + :param schema_simple_out_class_name: + """ + self.model = model + self.base_class_name = base_class_name + self.schema_simple_out_class_name = schema_simple_out_class_name + self.zh_name = zh_name + # model 文件的地址 + self.model_file_path = Path(inspect.getfile(sys.modules[model.__module__])) + # model 文件 app 路径 + self.app_dir_path = self.model_file_path.parent.parent + self.en_name = en_name + self.schema_file_path = schema_file_path + self.schemas_dir_path = schemas_dir_path + self.schema_init_file_path = self.schemas_dir_path / "__init__.py" + + def write_generate_code(self): + """ + 生成 schema 文件,以及代码内容 + :return: + """ + self.schema_file_path.parent.mkdir(parents=True, exist_ok=True) + if self.schema_file_path.exists(): + # 存在则直接删除,重新创建写入 + self.schema_file_path.unlink() + self.schema_file_path.touch() + self.schema_init_file_path.touch() + + code = self.generate_code() + self.schema_file_path.write_text(code, "utf-8") + + init_code = self.generate_init_code() + self.update_init_file(self.schema_init_file_path, init_code) + print(f"===========================schema 代码创建完成=================================") + + def generate_init_code(self): + """ + 生成 __init__ 文件导入代码 + todo 如果导入的类已经存在,则应该返回空 + :return: + """ + init_code = f"from .{self.en_name} import {self.base_class_name}, {self.schema_simple_out_class_name}" + return init_code + + def generate_code(self) -> str: + """ + 生成 schema 代码内容 + :return: + """ + fields = [] + mapper = model_inspect(self.model) + for attr_name, column_property in mapper.column_attrs.items(): + if attr_name in self.BASE_FIELDS: + continue + # 假设它是单列属性 + column: ColumnType = column_property.columns[0] + item = SchemaField( + name=attr_name, + field_type=column.type.python_type.__name__, + nullable=column.nullable, + default=column.default.__dict__.get("arg", None) if column.default else None, + title=column.comment, + max_length=column.type.__dict__.get("length", None) + ) + fields.append(item) + + code = self.generate_file_desc(self.schema_file_path.name, "1.0", "pydantic 模型,用于数据库序列化操作") + + modules = { + "pydantic": ['BaseModel', "Field", "ConfigDict"], + "core.data_types": ['DatetimeStr'], + } + code += self.generate_modules_code(modules) + + base_schema_code = f"\n\nclass {self.base_class_name}(BaseModel):" + for item in fields: + field = f"\n\t{item.name}: {item.field_type} {'| None ' if item.nullable else ''}" + default = None + if item.default is not None: + if item.field_type == "str": + default = f"\"{item.default}\"" + else: + default = item.default + elif default is None and not item.nullable: + default = "..." + + field += f"= Field({default}, title=\"{item.title}\")" + base_schema_code += field + base_schema_code += "\n" + code += base_schema_code + + base_out_schema_code = f"\n\nclass {self.schema_simple_out_class_name}({self.base_class_name}):" + base_out_schema_code += "\n\tmodel_config = ConfigDict(from_attributes=True)\n" + base_out_schema_code += "\n\tid: int = Field(..., title=\"编号\")" + base_out_schema_code += "\n\tcreate_datetime: DatetimeStr = Field(..., title=\"创建时间\")" + base_out_schema_code += "\n\tupdate_datetime: DatetimeStr = Field(..., title=\"更新时间\")" + base_out_schema_code += "\n" + code += base_out_schema_code + return code.replace("\t", " ") diff --git a/kinit-api/scripts/crud_generate/utils/view_generate.py b/kinit-api/scripts/crud_generate/utils/view_generate.py new file mode 100644 index 0000000..6a07103 --- /dev/null +++ b/kinit-api/scripts/crud_generate/utils/view_generate.py @@ -0,0 +1,143 @@ +import inspect +import sys +from pathlib import Path +from typing import Type +from core.database import Base +from .generate_base import GenerateBase + + +class ViewGenerate(GenerateBase): + + def __init__( + self, + model: Type[Base], + zh_name: str, + en_name: str, + schema_class_name: str, + schema_simple_out_class_name: str, + dal_class_name: str, + param_class_name: str + ): + """ + 初始化工作 + :param model: 提前定义好的 ORM 模型 + :param zh_name: 功能中文名称,主要用于描述、注释 + :param schema_class_name: + :param schema_simple_out_class_name: + :param dal_class_name: + :param param_class_name: + :param en_name: 功能英文名称,主要用于 schema、param 文件命名,以及它们的 class 命名,dal、url 命名,默认使用 model class + en_name 例子: + 如果 en_name 由多个单词组成那么请使用 _ 下划线拼接 + 在命名文件名称时,会执行使用 _ 下划线名称 + 在命名 class 名称时,会将下划线名称转换为大驼峰命名(CamelCase) + 在命名 url 时,会将下划线转换为 / + """ + self.model = model + self.schema_class_name = schema_class_name + self.schema_simple_out_class_name = schema_simple_out_class_name + self.dal_class_name = dal_class_name + self.param_class_name = param_class_name + self.zh_name = zh_name + self.en_name = en_name + # model 文件的地址 + self.model_file_path = Path(inspect.getfile(sys.modules[model.__module__])) + # model 文件 app 路径 + self.app_dir_path = self.model_file_path.parent.parent + # view 文件地址 + self.view_file_path = self.app_dir_path / "views.py" + + def write_generate_code(self): + """ + 生成 view 文件,以及代码内容 + :return: + """ + if self.view_file_path.exists(): + codes = self.file_code_split_module(self.view_file_path) + if codes: + print(f"==========view 文件已存在并已有代码内容,正在追加新代码============") + if not codes[0]: + # 无文件注释则添加文件注释 + codes[0] = self.generate_file_desc(self.view_file_path.name, "1.0", "视图层") + codes[1] = self.merge_dictionaries(codes[1], self.get_base_module_config()) + codes[2] += self.get_base_code_content() + code = '' + code += codes[0] + code += self.generate_modules_code(codes[1]) + if "app = APIRouter()" not in codes[2]: + code += "\n\napp = APIRouter()" + code += codes[2] + self.view_file_path.write_text(code, "utf-8") + print(f"=================view 代码已创建完成=====================") + return + else: + self.view_file_path.touch() + code = self.generate_code() + self.view_file_path.write_text(code, encoding="utf-8") + print(f"===============view 代码创建完成==================") + + def generate_code(self) -> str: + """ + 生成代码 + :return: + """ + code = self.generate_file_desc(self.view_file_path.name, "1.0", "路由,视图文件") + code += self.generate_modules_code(self.get_base_module_config()) + code += "\n\napp = APIRouter()" + code += self.get_base_code_content() + + return code.replace("\t", " ") + + @staticmethod + def get_base_module_config(): + """ + 获取基础模块导入配置 + :return: + """ + modules = { + "sqlalchemy.ext.asyncio": ['AsyncSession'], + "fastapi": ["APIRouter", "Depends"], + ".": ["models", "schemas", "crud", "params"], + "core.dependencies": ["IdList"], + "apps.vadmin.auth.utils.current": ["AllUserAuth"], + "utils.response": ["SuccessResponse"], + "apps.vadmin.auth.utils.validation.auth": ["Auth"], + "core.database": ["db_getter"], + } + return modules + + def get_base_code_content(self): + """ + 获取基础代码内容 + :return: + """ + base_code = "\n\n\n###########################################################" + base_code += f"\n# {self.zh_name}" + base_code += "\n###########################################################" + + router = self.en_name.replace("_", "/") + + base_code += f"\n@app.get(\"/{router}\", summary=\"获取{self.zh_name}列表\", tags=[\"{self.zh_name}\"])" + base_code += f"\nasync def get_{self.en_name}_list(p: params.{self.param_class_name} = Depends(), auth: Auth = Depends(AllUserAuth())):" + base_code += f"\n\tdatas, count = await crud.{self.dal_class_name}(auth.db).get_datas(**p.dict(), v_return_count=True)" + base_code += f"\n\treturn SuccessResponse(datas, count=count)\n" + + base_code += f"\n\n@app.post(\"/{router}\", summary=\"创建{self.zh_name}\", tags=[\"{self.zh_name}\"])" + base_code += f"\nasync def create_{self.en_name}(data: schemas.{self.schema_class_name}, auth: Auth = Depends(AllUserAuth())):" + base_code += f"\n\treturn SuccessResponse(await crud.{self.dal_class_name}(auth.db).create_data(data=data))\n" + + base_code += f"\n\n@app.delete(\"/{router}\", summary=\"删除{self.zh_name}\", description=\"硬删除\", tags=[\"{self.zh_name}\"])" + base_code += f"\nasync def delete_{self.en_name}_list(ids: IdList = Depends(), auth: Auth = Depends(AllUserAuth())):" + base_code += f"\n\tawait crud.{self.dal_class_name}(auth.db).delete_datas(ids=ids.ids, v_soft=False)" + base_code += f"\n\treturn SuccessResponse(\"删除成功\")\n" + + base_code += f"\n\n@app.put(\"/{router}" + "/{data_id}\"" + f", summary=\"更新{self.zh_name}\", tags=[\"{self.zh_name}\"])" + base_code += f"\nasync def put_{self.en_name}(data_id: int, data: schemas.{self.schema_class_name}, auth: Auth = Depends(AllUserAuth())):" + base_code += f"\n\treturn SuccessResponse(await crud.{self.dal_class_name}(auth.db).put_data(data_id, data))\n" + + base_code += f"\n\n@app.get(\"/{router}" + "/{data_id}\"" + f", summary=\"获取{self.zh_name}信息\", tags=[\"{self.zh_name}\"])" + base_code += f"\nasync def get_{self.en_name}(data_id: int, db: AsyncSession = Depends(db_getter)):" + base_code += f"\n\tschema = schemas.{self.schema_simple_out_class_name}" + base_code += f"\n\treturn SuccessResponse(await crud.{self.dal_class_name}(db).get_data(data_id, v_schema=schema))\n" + base_code += "\n" + return base_code.replace("\t", " ")