From 518f9d4a47de9a135155e1f6ceea6eac73b6114b Mon Sep 17 00:00:00 2001 From: ktianc Date: Sat, 16 Dec 2023 21:16:49 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E6=A0=B8=E5=BF=83=E7=B1=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- kinit-api/core/crud.py | 125 ++++++++++++++++++--------------- kinit-api/core/data_types.py | 12 ++-- kinit-api/core/enum.py | 2 +- kinit-api/core/mongo_manage.py | 16 ++++- 4 files changed, 88 insertions(+), 67 deletions(-) diff --git a/kinit-api/core/crud.py b/kinit-api/core/crud.py index 9b34d7e..7b5f3b6 100644 --- a/kinit-api/core/crud.py +++ b/kinit-api/core/crud.py @@ -14,21 +14,21 @@ import datetime from fastapi import HTTPException from fastapi.encoders import jsonable_encoder -from sqlalchemy import func, delete, update, BinaryExpression, ScalarResult, select +from sqlalchemy import func, delete, update, BinaryExpression, ScalarResult, select, false, insert from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import InstrumentedAttribute from sqlalchemy.orm.strategy_options import _AbstractLoad from starlette import status from core.exception import CustomException from sqlalchemy.sql.selectable import Select as SelectType -from typing import Any +from typing import Any, List, Union class DalBase: # 倒叙 ORDER_FIELD = ["desc", "descending"] - def __init__(self, db: AsyncSession, model: Any, schema: Any): + def __init__(self, db: AsyncSession = None, model: Any = None, schema: Any = None): self.db = db self.model = model self.schema = schema @@ -37,11 +37,11 @@ class DalBase: self, data_id: int = None, v_start_sql: SelectType = None, - v_select_from: list[Any] = None, - v_join: list[list[str | InstrumentedAttribute, BinaryExpression | None]] = None, - v_outerjoin: list[list[str | InstrumentedAttribute, BinaryExpression | None]] = None, - v_options: list[_AbstractLoad] = None, - v_where: list[BinaryExpression] = None, + v_select_from: List[Any] = None, + v_join: List[Any] = None, + v_outer_join: List[Any] = None, + v_options: List[_AbstractLoad] = None, + v_where: List[BinaryExpression] = None, v_order: str = None, v_order_field: str = None, v_return_none: bool = False, @@ -55,7 +55,7 @@ class DalBase: :param v_start_sql: 初始 sql :param v_select_from: 用于指定查询从哪个表开始,通常与 .join() 等方法一起使用。 :param v_join: 创建内连接(INNER JOIN)操作,返回两个表中满足连接条件的交集。 - :param v_outerjoin: 用于创建外连接(OUTER JOIN)操作,返回两个表中满足连接条件的并集,包括未匹配的行,并用 NULL 值填充。 + :param v_outer_join: 用于创建外连接(OUTER JOIN)操作,返回两个表中满足连接条件的并集,包括未匹配的行,并用 NULL 值填充。 :param v_options: 用于为查询添加附加选项,如预加载、延迟加载等。 :param v_where: 当前表查询条件,原始表达式 :param v_order: 排序,默认正序,为 desc 是倒叙 @@ -66,16 +66,16 @@ class DalBase: :return: 默认返回 ORM 对象,如果存在 v_schema 则会返回 v_schema 结果 """ if not isinstance(v_start_sql, SelectType): - v_start_sql = select(self.model).where(self.model.is_delete == False) + v_start_sql = select(self.model).where(self.model.is_delete == false()) - if data_id: + if data_id is not None: v_start_sql = v_start_sql.where(self.model.id == data_id) queryset: ScalarResult = await self.filter_core( v_start_sql=v_start_sql, v_select_from=v_select_from, v_join=v_join, - v_outerjoin=v_outerjoin, + v_outer_join=v_outer_join, v_options=v_options, v_where=v_where, v_order=v_order, @@ -105,19 +105,20 @@ class DalBase: page: int = 1, limit: int = 10, v_start_sql: SelectType = None, - v_select_from: list[Any] = None, - v_join: list[list[str | InstrumentedAttribute, BinaryExpression | None]] = None, - v_outerjoin: list[list[str | InstrumentedAttribute, BinaryExpression | None]] = None, - v_options: list[_AbstractLoad] = None, - v_where: list[BinaryExpression] = None, + v_select_from: List[Any] = None, + v_join: List[Any] = None, + v_outer_join: List[Any] = None, + v_options: List[_AbstractLoad] = None, + v_where: List[BinaryExpression] = None, v_order: str = None, v_order_field: str = None, v_return_count: bool = False, v_return_scalars: bool = False, v_return_objs: bool = False, v_schema: Any = None, + v_distinct: bool = False, **kwargs - ) -> list[Any] | ScalarResult | tuple: + ) -> Union[List[Any], ScalarResult, tuple]: """ 获取数据列表 @@ -126,7 +127,7 @@ class DalBase: :param v_start_sql: 初始 sql :param v_select_from: 用于指定查询从哪个表开始,通常与 .join() 等方法一起使用。 :param v_join: 创建内连接(INNER JOIN)操作,返回两个表中满足连接条件的交集。 - :param v_outerjoin: 用于创建外连接(OUTER JOIN)操作,返回两个表中满足连接条件的并集,包括未匹配的行,并用 NULL 值填充。 + :param v_outer_join: 用于创建外连接(OUTER JOIN)操作,返回两个表中满足连接条件的并集,包括未匹配的行,并用 NULL 值填充。 :param v_options: 用于为查询添加附加选项,如预加载、延迟加载等。 :param v_where: 当前表查询条件,原始表达式 :param v_order: 排序,默认正序,为 desc 是倒叙 @@ -135,6 +136,7 @@ class DalBase: :param v_return_scalars: 返回scalars后的结果 :param v_return_objs: 是否返回对象 :param v_schema: 指定使用的序列化对象 + :param v_distinct: 是否结果去重 :param kwargs: 查询参数,使用的是自定义表达式 :return: 返回值优先级:v_return_scalars > v_return_objs > v_schema """ @@ -142,7 +144,7 @@ class DalBase: v_start_sql=v_start_sql, v_select_from=v_select_from, v_join=v_join, - v_outerjoin=v_outerjoin, + v_outer_join=v_outer_join, v_options=v_options, v_where=v_where, v_order=v_order, @@ -151,6 +153,9 @@ class DalBase: **kwargs ) + if v_distinct: + sql = sql.distinct() + count = 0 if v_return_count: count_sql = select(func.count()).select_from(sql.alias()) @@ -184,10 +189,10 @@ class DalBase: async def get_count( self, - v_select_from: list[Any] = None, - v_join: list[list[str | InstrumentedAttribute, BinaryExpression | None]] = None, - v_outerjoin: list[list[str | InstrumentedAttribute, BinaryExpression | None]] = None, - v_where: list[BinaryExpression] = None, + v_select_from: List[Any] = None, + v_join: List[Any] = None, + v_outer_join: List[Any] = None, + v_where: List[BinaryExpression] = None, **kwargs ) -> int: """ @@ -195,7 +200,7 @@ class DalBase: :param v_select_from: 用于指定查询从哪个表开始,通常与 .join() 等方法一起使用。 :param v_join: 创建内连接(INNER JOIN)操作,返回两个表中满足连接条件的交集。 - :param v_outerjoin: 用于创建外连接(OUTER JOIN)操作,返回两个表中满足连接条件的并集,包括未匹配的行,并用 NULL 值填充。 + :param v_outer_join: 用于创建外连接(OUTER JOIN)操作,返回两个表中满足连接条件的并集,包括未匹配的行,并用 NULL 值填充。 :param v_where: 当前表查询条件,原始表达式 :param kwargs: 查询参数 """ @@ -204,7 +209,7 @@ class DalBase: v_start_sql=v_start_sql, v_select_from=v_select_from, v_join=v_join, - v_outerjoin=v_outerjoin, + v_outer_join=v_outer_join, v_where=v_where, v_return_sql=True, **kwargs @@ -215,13 +220,12 @@ class DalBase: async def create_data( self, data, - v_options: list[_AbstractLoad] = None, + v_options: List[_AbstractLoad] = None, v_return_obj: bool = False, v_schema: Any = None ) -> Any: """ 创建单个数据 - :param data: 创建数据 :param v_options: 指示应使用select在预加载中加载给定的属性。 :param v_schema: ,指定使用的序列化对象 @@ -234,22 +238,22 @@ class DalBase: await self.flush(obj) return await self.out_dict(obj, v_options, v_return_obj, v_schema) - # async def create_datas(self, datas: list[dict]) -> None: - # """ - # 批量创建数据,暂不启用 - # SQLAlchemy 2.0 批量插入不支持 MySQL 返回对象: - # https://docs.sqlalchemy.org/en/20/orm/queryguide/dml.html#getting-new-objects-with-returning - # - # :param datas: 字典数据列表 - # """ - # await self.db.execute(insert(self.model), datas) - # await self.db.flush() + async def create_datas(self, datas: List[dict]) -> None: + """ + 批量创建数据 + SQLAlchemy 2.0 批量插入不支持 MySQL 返回值: + https://docs.sqlalchemy.org/en/20/orm/queryguide/dml.html#getting-new-objects-with-returning + + :param datas: 字典数据列表 + """ + await self.db.execute(insert(self.model), datas) + await self.db.flush() async def put_data( self, data_id: int, data: Any, - v_options: list[_AbstractLoad] = None, + v_options: List[_AbstractLoad] = None, v_return_obj: bool = False, v_schema: Any = None ) -> Any: @@ -268,7 +272,7 @@ class DalBase: await self.flush(obj) return await self.out_dict(obj, None, v_return_obj, v_schema) - async def delete_datas(self, ids: list[int], v_soft: bool = False, **kwargs) -> None: + async def delete_datas(self, ids: List[int], v_soft: bool = False, **kwargs) -> None: """ 删除多条数据 :param ids: 数据集 @@ -290,18 +294,21 @@ class DalBase: async def flush(self, obj: Any = None) -> Any: """ 刷新到数据库 + :param obj: + :return: """ if obj: self.db.add(obj) await self.db.flush() if obj: + # 使用 get_data 或者 get_datas 获取到实例后如何更新了实例,并需要序列化实例,那么需要执行 refresh 刷新才能正常序列化 await self.db.refresh(obj) return obj async def out_dict( self, obj: Any, - v_options: list[_AbstractLoad] = None, + v_options: List[_AbstractLoad] = None, v_return_obj: bool = False, v_schema: Any = None ) -> Any: @@ -324,23 +331,23 @@ class DalBase: async def filter_core( self, v_start_sql: SelectType = None, - v_select_from: list[Any] = None, - v_join: list[list[str | InstrumentedAttribute, BinaryExpression | None]] = None, - v_outerjoin: list[list[str | InstrumentedAttribute, BinaryExpression | None]] = None, - v_options: list[_AbstractLoad] = None, - v_where: list[BinaryExpression] = None, + v_select_from: List[Any] = None, + v_join: List[Any] = None, + v_outer_join: List[Any] = None, + v_options: List[_AbstractLoad] = None, + v_where: List[BinaryExpression] = None, v_order: str = None, v_order_field: str = None, v_return_sql: bool = False, **kwargs - ) -> ScalarResult | SelectType: + ) -> Union[ScalarResult, SelectType]: """ 数据过滤核心功能 :param v_start_sql: 初始 sql :param v_select_from: 用于指定查询从哪个表开始,通常与 .join() 等方法一起使用。 :param v_join: 创建内连接(INNER JOIN)操作,返回两个表中满足连接条件的交集。 - :param v_outerjoin: 用于创建外连接(OUTER JOIN)操作,返回两个表中满足连接条件的并集,包括未匹配的行,并用 NULL 值填充。 + :param v_outer_join: 用于创建外连接(OUTER JOIN)操作,返回两个表中满足连接条件的并集,包括未匹配的行,并用 NULL 值填充。 :param v_options: 用于为查询添加附加选项,如预加载、延迟加载等。 :param v_where: 当前表查询条件,原始表达式 :param v_order: 排序,默认正序,为 desc 是倒叙 @@ -349,13 +356,13 @@ class DalBase: :return: 返回过滤后的总数居 或 sql """ if not isinstance(v_start_sql, SelectType): - v_start_sql = select(self.model).where(self.model.is_delete == False) + v_start_sql = select(self.model).where(self.model.is_delete == false()) sql = self.add_relation( v_start_sql=v_start_sql, v_select_from=v_select_from, v_join=v_join, - v_outerjoin=v_outerjoin, + v_outer_join=v_outer_join, v_options=v_options ) @@ -381,16 +388,16 @@ class DalBase: def add_relation( self, v_start_sql: SelectType, - v_select_from: list[Any] = None, - v_join: list[list[str | InstrumentedAttribute, BinaryExpression | None]] = None, - v_outerjoin: list[list[str | InstrumentedAttribute, BinaryExpression | None]] = None, - v_options: list[_AbstractLoad] = None, + v_select_from: List[Any] = None, + v_join: List[Any] = None, + v_outer_join: List[Any] = None, + v_options: List[_AbstractLoad] = None, ) -> SelectType: """ :param v_start_sql: 初始 sql :param v_select_from: 用于指定查询从哪个表开始,通常与 .join() 等方法一起使用。 :param v_join: 创建内连接(INNER JOIN)操作,返回两个表中满足连接条件的交集。 - :param v_outerjoin: 用于创建外连接(OUTER JOIN)操作,返回两个表中满足连接条件的并集,包括未匹配的行,并用 NULL 值填充。 + :param v_outer_join: 用于创建外连接(OUTER JOIN)操作,返回两个表中满足连接条件的并集,包括未匹配的行,并用 NULL 值填充。 :param v_options: 用于为查询添加附加选项,如预加载、延迟加载等。 """ if v_select_from: @@ -406,8 +413,8 @@ class DalBase: else: v_start_sql = v_start_sql.join(table) - if v_outerjoin: - for relation in v_outerjoin: + if v_outer_join: + for relation in v_outer_join: table = relation[0] if isinstance(table, str): table = getattr(self.model, table) @@ -432,7 +439,7 @@ class DalBase: sql = sql.where(*conditions) return sql - def __dict_filter(self, **kwargs) -> list[BinaryExpression]: + def __dict_filter(self, **kwargs) -> List[BinaryExpression]: """ 字典过滤 :param model: @@ -466,6 +473,8 @@ class DalBase: conditions.append(attr != value[1]) elif value[0] == ">": conditions.append(attr > value[1]) + elif value[0] == ">=": + conditions.append(attr >= value[1]) elif value[0] == "<=": conditions.append(attr <= value[1]) else: diff --git a/kinit-api/core/data_types.py b/kinit-api/core/data_types.py index 2b26762..243975e 100644 --- a/kinit-api/core/data_types.py +++ b/kinit-api/core/data_types.py @@ -17,7 +17,7 @@ from pydantic import AfterValidator, PlainSerializer, WithJsonSchema from .validator import * -def DatetimeStrVali(value: str | datetime.datetime | int | float | dict): +def datetime_str_vali(value: str | datetime.datetime | int | float | dict): """ 日期时间字符串验证 如果我传入的是字符串,那么直接返回,如果我传入的是一个日期类型,那么会转为字符串格式后返回 @@ -48,7 +48,7 @@ def DatetimeStrVali(value: str | datetime.datetime | int | float | dict): # 实现自定义一个日期时间字符串的数据类型 DatetimeStr = Annotated[ str | datetime.datetime | int | float | dict, - AfterValidator(DatetimeStrVali), + AfterValidator(datetime_str_vali), PlainSerializer(lambda x: x, return_type=str), WithJsonSchema({'type': 'string'}, mode='serialization') ] @@ -72,7 +72,7 @@ Email = Annotated[ ] -def DateStrVali(value: str | datetime.date | int | float): +def date_str_vali(value: str | datetime.date | int | float): """ 日期字符串验证 如果我传入的是字符串,那么直接返回,如果我传入的是一个日期类型,那么会转为字符串格式后返回 @@ -95,13 +95,13 @@ def DateStrVali(value: str | datetime.date | int | float): # 实现自定义一个日期字符串的数据类型 DateStr = Annotated[ str | datetime.date | int | float, - AfterValidator(DateStrVali), + AfterValidator(date_str_vali), PlainSerializer(lambda x: x, return_type=str), WithJsonSchema({'type': 'string'}, mode='serialization') ] -def ObjectIdStrVali(value: str | dict | ObjectId): +def object_id_str_vali(value: str | dict | ObjectId): """ 官方文档:https://docs.pydantic.dev/dev-v2/usage/types/datetime/ """ @@ -116,7 +116,7 @@ def ObjectIdStrVali(value: str | dict | ObjectId): ObjectIdStr = Annotated[ Any, # 这里不能直接使用 any,需要使用 typing.Any - AfterValidator(ObjectIdStrVali), + AfterValidator(object_id_str_vali), PlainSerializer(lambda x: x, return_type=str), WithJsonSchema({'type': 'string'}, mode='serialization') ] diff --git a/kinit-api/core/enum.py b/kinit-api/core/enum.py index 8fac768..47b2c5d 100644 --- a/kinit-api/core/enum.py +++ b/kinit-api/core/enum.py @@ -24,4 +24,4 @@ class SuperEnum(Enum): @classmethod def values(cls): """Returns a list of all the enum values.""" - return list(cls._value2member_map_.keys()) \ No newline at end of file + return list(cls._value2member_map_.keys()) diff --git a/kinit-api/core/mongo_manage.py b/kinit-api/core/mongo_manage.py index a88a2ac..a0d6549 100644 --- a/kinit-api/core/mongo_manage.py +++ b/kinit-api/core/mongo_manage.py @@ -22,7 +22,20 @@ class MongoManage: # 倒叙 ORDER_FIELD = ["desc", "descending"] - def __init__(self, db: AsyncIOMotorDatabase, collection: str, schema: Any = None, is_object_id: bool = True): + def __init__( + self, + db: AsyncIOMotorDatabase = None, + collection: str = None, + schema: Any = None, + is_object_id: bool = True + ): + """ + 初始化 + :param db: + :param collection: 集合 + :param schema: + :param is_object_id: _id 列是否为 ObjectId 格式 + """ self.db = db self.collection = db[collection] self.schema = schema @@ -37,7 +50,6 @@ class MongoManage: ) -> dict | None: """ 获取单个数据,默认使用 ID 查询,否则使用关键词查询 - :param _id: 数据 ID :param v_return_none: 是否返回空 None,否则抛出异常,默认抛出异常 :param v_schema: 指定使用的序列化对象