优化核心类

This commit is contained in:
ktianc 2023-12-16 21:16:49 +08:00
parent 26ffb4c167
commit 518f9d4a47
4 changed files with 88 additions and 67 deletions

View File

@ -14,21 +14,21 @@
import datetime import datetime
from fastapi import HTTPException from fastapi import HTTPException
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
from sqlalchemy import func, delete, update, BinaryExpression, ScalarResult, select from sqlalchemy import func, delete, update, BinaryExpression, ScalarResult, select, false, insert
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import InstrumentedAttribute from sqlalchemy.orm import InstrumentedAttribute
from sqlalchemy.orm.strategy_options import _AbstractLoad from sqlalchemy.orm.strategy_options import _AbstractLoad
from starlette import status from starlette import status
from core.exception import CustomException from core.exception import CustomException
from sqlalchemy.sql.selectable import Select as SelectType from sqlalchemy.sql.selectable import Select as SelectType
from typing import Any from typing import Any, List, Union
class DalBase: class DalBase:
# 倒叙 # 倒叙
ORDER_FIELD = ["desc", "descending"] ORDER_FIELD = ["desc", "descending"]
def __init__(self, db: AsyncSession, model: Any, schema: Any): def __init__(self, db: AsyncSession = None, model: Any = None, schema: Any = None):
self.db = db self.db = db
self.model = model self.model = model
self.schema = schema self.schema = schema
@ -37,11 +37,11 @@ class DalBase:
self, self,
data_id: int = None, data_id: int = None,
v_start_sql: SelectType = None, v_start_sql: SelectType = None,
v_select_from: list[Any] = None, v_select_from: List[Any] = None,
v_join: list[list[str | InstrumentedAttribute, BinaryExpression | None]] = None, v_join: List[Any] = None,
v_outerjoin: list[list[str | InstrumentedAttribute, BinaryExpression | None]] = None, v_outer_join: List[Any] = None,
v_options: list[_AbstractLoad] = None, v_options: List[_AbstractLoad] = None,
v_where: list[BinaryExpression] = None, v_where: List[BinaryExpression] = None,
v_order: str = None, v_order: str = None,
v_order_field: str = None, v_order_field: str = None,
v_return_none: bool = False, v_return_none: bool = False,
@ -55,7 +55,7 @@ class DalBase:
:param v_start_sql: 初始 sql :param v_start_sql: 初始 sql
:param v_select_from: 用于指定查询从哪个表开始通常与 .join() 等方法一起使用 :param v_select_from: 用于指定查询从哪个表开始通常与 .join() 等方法一起使用
:param v_join: 创建内连接INNER JOIN操作返回两个表中满足连接条件的交集 :param v_join: 创建内连接INNER JOIN操作返回两个表中满足连接条件的交集
:param v_outerjoin: 用于创建外连接OUTER JOIN操作返回两个表中满足连接条件的并集包括未匹配的行并用 NULL 值填充 :param v_outer_join: 用于创建外连接OUTER JOIN操作返回两个表中满足连接条件的并集包括未匹配的行并用 NULL 值填充
:param v_options: 用于为查询添加附加选项如预加载延迟加载等 :param v_options: 用于为查询添加附加选项如预加载延迟加载等
:param v_where: 当前表查询条件原始表达式 :param v_where: 当前表查询条件原始表达式
:param v_order: 排序默认正序 desc 是倒叙 :param v_order: 排序默认正序 desc 是倒叙
@ -66,16 +66,16 @@ class DalBase:
:return: 默认返回 ORM 对象如果存在 v_schema 则会返回 v_schema 结果 :return: 默认返回 ORM 对象如果存在 v_schema 则会返回 v_schema 结果
""" """
if not isinstance(v_start_sql, SelectType): if not isinstance(v_start_sql, SelectType):
v_start_sql = select(self.model).where(self.model.is_delete == False) v_start_sql = select(self.model).where(self.model.is_delete == false())
if data_id: if data_id is not None:
v_start_sql = v_start_sql.where(self.model.id == data_id) v_start_sql = v_start_sql.where(self.model.id == data_id)
queryset: ScalarResult = await self.filter_core( queryset: ScalarResult = await self.filter_core(
v_start_sql=v_start_sql, v_start_sql=v_start_sql,
v_select_from=v_select_from, v_select_from=v_select_from,
v_join=v_join, v_join=v_join,
v_outerjoin=v_outerjoin, v_outer_join=v_outer_join,
v_options=v_options, v_options=v_options,
v_where=v_where, v_where=v_where,
v_order=v_order, v_order=v_order,
@ -105,19 +105,20 @@ class DalBase:
page: int = 1, page: int = 1,
limit: int = 10, limit: int = 10,
v_start_sql: SelectType = None, v_start_sql: SelectType = None,
v_select_from: list[Any] = None, v_select_from: List[Any] = None,
v_join: list[list[str | InstrumentedAttribute, BinaryExpression | None]] = None, v_join: List[Any] = None,
v_outerjoin: list[list[str | InstrumentedAttribute, BinaryExpression | None]] = None, v_outer_join: List[Any] = None,
v_options: list[_AbstractLoad] = None, v_options: List[_AbstractLoad] = None,
v_where: list[BinaryExpression] = None, v_where: List[BinaryExpression] = None,
v_order: str = None, v_order: str = None,
v_order_field: str = None, v_order_field: str = None,
v_return_count: bool = False, v_return_count: bool = False,
v_return_scalars: bool = False, v_return_scalars: bool = False,
v_return_objs: bool = False, v_return_objs: bool = False,
v_schema: Any = None, v_schema: Any = None,
v_distinct: bool = False,
**kwargs **kwargs
) -> list[Any] | ScalarResult | tuple: ) -> Union[List[Any], ScalarResult, tuple]:
""" """
获取数据列表 获取数据列表
@ -126,7 +127,7 @@ class DalBase:
:param v_start_sql: 初始 sql :param v_start_sql: 初始 sql
:param v_select_from: 用于指定查询从哪个表开始通常与 .join() 等方法一起使用 :param v_select_from: 用于指定查询从哪个表开始通常与 .join() 等方法一起使用
:param v_join: 创建内连接INNER JOIN操作返回两个表中满足连接条件的交集 :param v_join: 创建内连接INNER JOIN操作返回两个表中满足连接条件的交集
:param v_outerjoin: 用于创建外连接OUTER JOIN操作返回两个表中满足连接条件的并集包括未匹配的行并用 NULL 值填充 :param v_outer_join: 用于创建外连接OUTER JOIN操作返回两个表中满足连接条件的并集包括未匹配的行并用 NULL 值填充
:param v_options: 用于为查询添加附加选项如预加载延迟加载等 :param v_options: 用于为查询添加附加选项如预加载延迟加载等
:param v_where: 当前表查询条件原始表达式 :param v_where: 当前表查询条件原始表达式
:param v_order: 排序默认正序 desc 是倒叙 :param v_order: 排序默认正序 desc 是倒叙
@ -135,6 +136,7 @@ class DalBase:
:param v_return_scalars: 返回scalars后的结果 :param v_return_scalars: 返回scalars后的结果
:param v_return_objs: 是否返回对象 :param v_return_objs: 是否返回对象
:param v_schema: 指定使用的序列化对象 :param v_schema: 指定使用的序列化对象
:param v_distinct: 是否结果去重
:param kwargs: 查询参数使用的是自定义表达式 :param kwargs: 查询参数使用的是自定义表达式
:return: 返回值优先级v_return_scalars > v_return_objs > v_schema :return: 返回值优先级v_return_scalars > v_return_objs > v_schema
""" """
@ -142,7 +144,7 @@ class DalBase:
v_start_sql=v_start_sql, v_start_sql=v_start_sql,
v_select_from=v_select_from, v_select_from=v_select_from,
v_join=v_join, v_join=v_join,
v_outerjoin=v_outerjoin, v_outer_join=v_outer_join,
v_options=v_options, v_options=v_options,
v_where=v_where, v_where=v_where,
v_order=v_order, v_order=v_order,
@ -151,6 +153,9 @@ class DalBase:
**kwargs **kwargs
) )
if v_distinct:
sql = sql.distinct()
count = 0 count = 0
if v_return_count: if v_return_count:
count_sql = select(func.count()).select_from(sql.alias()) count_sql = select(func.count()).select_from(sql.alias())
@ -184,10 +189,10 @@ class DalBase:
async def get_count( async def get_count(
self, self,
v_select_from: list[Any] = None, v_select_from: List[Any] = None,
v_join: list[list[str | InstrumentedAttribute, BinaryExpression | None]] = None, v_join: List[Any] = None,
v_outerjoin: list[list[str | InstrumentedAttribute, BinaryExpression | None]] = None, v_outer_join: List[Any] = None,
v_where: list[BinaryExpression] = None, v_where: List[BinaryExpression] = None,
**kwargs **kwargs
) -> int: ) -> int:
""" """
@ -195,7 +200,7 @@ class DalBase:
:param v_select_from: 用于指定查询从哪个表开始通常与 .join() 等方法一起使用 :param v_select_from: 用于指定查询从哪个表开始通常与 .join() 等方法一起使用
:param v_join: 创建内连接INNER JOIN操作返回两个表中满足连接条件的交集 :param v_join: 创建内连接INNER JOIN操作返回两个表中满足连接条件的交集
:param v_outerjoin: 用于创建外连接OUTER JOIN操作返回两个表中满足连接条件的并集包括未匹配的行并用 NULL 值填充 :param v_outer_join: 用于创建外连接OUTER JOIN操作返回两个表中满足连接条件的并集包括未匹配的行并用 NULL 值填充
:param v_where: 当前表查询条件原始表达式 :param v_where: 当前表查询条件原始表达式
:param kwargs: 查询参数 :param kwargs: 查询参数
""" """
@ -204,7 +209,7 @@ class DalBase:
v_start_sql=v_start_sql, v_start_sql=v_start_sql,
v_select_from=v_select_from, v_select_from=v_select_from,
v_join=v_join, v_join=v_join,
v_outerjoin=v_outerjoin, v_outer_join=v_outer_join,
v_where=v_where, v_where=v_where,
v_return_sql=True, v_return_sql=True,
**kwargs **kwargs
@ -215,13 +220,12 @@ class DalBase:
async def create_data( async def create_data(
self, self,
data, data,
v_options: list[_AbstractLoad] = None, v_options: List[_AbstractLoad] = None,
v_return_obj: bool = False, v_return_obj: bool = False,
v_schema: Any = None v_schema: Any = None
) -> Any: ) -> Any:
""" """
创建单个数据 创建单个数据
:param data: 创建数据 :param data: 创建数据
:param v_options: 指示应使用select在预加载中加载给定的属性 :param v_options: 指示应使用select在预加载中加载给定的属性
:param v_schema: 指定使用的序列化对象 :param v_schema: 指定使用的序列化对象
@ -234,22 +238,22 @@ class DalBase:
await self.flush(obj) await self.flush(obj)
return await self.out_dict(obj, v_options, v_return_obj, v_schema) return await self.out_dict(obj, v_options, v_return_obj, v_schema)
# async def create_datas(self, datas: list[dict]) -> None: async def create_datas(self, datas: List[dict]) -> None:
# """ """
# 批量创建数据,暂不启用 批量创建数据
# SQLAlchemy 2.0 批量插入不支持 MySQL 返回对象 SQLAlchemy 2.0 批量插入不支持 MySQL 返回值
# https://docs.sqlalchemy.org/en/20/orm/queryguide/dml.html#getting-new-objects-with-returning https://docs.sqlalchemy.org/en/20/orm/queryguide/dml.html#getting-new-objects-with-returning
#
# :param datas: 字典数据列表 :param datas: 字典数据列表
# """ """
# await self.db.execute(insert(self.model), datas) await self.db.execute(insert(self.model), datas)
# await self.db.flush() await self.db.flush()
async def put_data( async def put_data(
self, self,
data_id: int, data_id: int,
data: Any, data: Any,
v_options: list[_AbstractLoad] = None, v_options: List[_AbstractLoad] = None,
v_return_obj: bool = False, v_return_obj: bool = False,
v_schema: Any = None v_schema: Any = None
) -> Any: ) -> Any:
@ -268,7 +272,7 @@ class DalBase:
await self.flush(obj) await self.flush(obj)
return await self.out_dict(obj, None, v_return_obj, v_schema) return await self.out_dict(obj, None, v_return_obj, v_schema)
async def delete_datas(self, ids: list[int], v_soft: bool = False, **kwargs) -> None: async def delete_datas(self, ids: List[int], v_soft: bool = False, **kwargs) -> None:
""" """
删除多条数据 删除多条数据
:param ids: 数据集 :param ids: 数据集
@ -290,18 +294,21 @@ class DalBase:
async def flush(self, obj: Any = None) -> Any: async def flush(self, obj: Any = None) -> Any:
""" """
刷新到数据库 刷新到数据库
:param obj:
:return:
""" """
if obj: if obj:
self.db.add(obj) self.db.add(obj)
await self.db.flush() await self.db.flush()
if obj: if obj:
# 使用 get_data 或者 get_datas 获取到实例后如何更新了实例,并需要序列化实例,那么需要执行 refresh 刷新才能正常序列化
await self.db.refresh(obj) await self.db.refresh(obj)
return obj return obj
async def out_dict( async def out_dict(
self, self,
obj: Any, obj: Any,
v_options: list[_AbstractLoad] = None, v_options: List[_AbstractLoad] = None,
v_return_obj: bool = False, v_return_obj: bool = False,
v_schema: Any = None v_schema: Any = None
) -> Any: ) -> Any:
@ -324,23 +331,23 @@ class DalBase:
async def filter_core( async def filter_core(
self, self,
v_start_sql: SelectType = None, v_start_sql: SelectType = None,
v_select_from: list[Any] = None, v_select_from: List[Any] = None,
v_join: list[list[str | InstrumentedAttribute, BinaryExpression | None]] = None, v_join: List[Any] = None,
v_outerjoin: list[list[str | InstrumentedAttribute, BinaryExpression | None]] = None, v_outer_join: List[Any] = None,
v_options: list[_AbstractLoad] = None, v_options: List[_AbstractLoad] = None,
v_where: list[BinaryExpression] = None, v_where: List[BinaryExpression] = None,
v_order: str = None, v_order: str = None,
v_order_field: str = None, v_order_field: str = None,
v_return_sql: bool = False, v_return_sql: bool = False,
**kwargs **kwargs
) -> ScalarResult | SelectType: ) -> Union[ScalarResult, SelectType]:
""" """
数据过滤核心功能 数据过滤核心功能
:param v_start_sql: 初始 sql :param v_start_sql: 初始 sql
:param v_select_from: 用于指定查询从哪个表开始通常与 .join() 等方法一起使用 :param v_select_from: 用于指定查询从哪个表开始通常与 .join() 等方法一起使用
:param v_join: 创建内连接INNER JOIN操作返回两个表中满足连接条件的交集 :param v_join: 创建内连接INNER JOIN操作返回两个表中满足连接条件的交集
:param v_outerjoin: 用于创建外连接OUTER JOIN操作返回两个表中满足连接条件的并集包括未匹配的行并用 NULL 值填充 :param v_outer_join: 用于创建外连接OUTER JOIN操作返回两个表中满足连接条件的并集包括未匹配的行并用 NULL 值填充
:param v_options: 用于为查询添加附加选项如预加载延迟加载等 :param v_options: 用于为查询添加附加选项如预加载延迟加载等
:param v_where: 当前表查询条件原始表达式 :param v_where: 当前表查询条件原始表达式
:param v_order: 排序默认正序 desc 是倒叙 :param v_order: 排序默认正序 desc 是倒叙
@ -349,13 +356,13 @@ class DalBase:
:return: 返回过滤后的总数居 sql :return: 返回过滤后的总数居 sql
""" """
if not isinstance(v_start_sql, SelectType): if not isinstance(v_start_sql, SelectType):
v_start_sql = select(self.model).where(self.model.is_delete == False) v_start_sql = select(self.model).where(self.model.is_delete == false())
sql = self.add_relation( sql = self.add_relation(
v_start_sql=v_start_sql, v_start_sql=v_start_sql,
v_select_from=v_select_from, v_select_from=v_select_from,
v_join=v_join, v_join=v_join,
v_outerjoin=v_outerjoin, v_outer_join=v_outer_join,
v_options=v_options v_options=v_options
) )
@ -381,16 +388,16 @@ class DalBase:
def add_relation( def add_relation(
self, self,
v_start_sql: SelectType, v_start_sql: SelectType,
v_select_from: list[Any] = None, v_select_from: List[Any] = None,
v_join: list[list[str | InstrumentedAttribute, BinaryExpression | None]] = None, v_join: List[Any] = None,
v_outerjoin: list[list[str | InstrumentedAttribute, BinaryExpression | None]] = None, v_outer_join: List[Any] = None,
v_options: list[_AbstractLoad] = None, v_options: List[_AbstractLoad] = None,
) -> SelectType: ) -> SelectType:
""" """
:param v_start_sql: 初始 sql :param v_start_sql: 初始 sql
:param v_select_from: 用于指定查询从哪个表开始通常与 .join() 等方法一起使用 :param v_select_from: 用于指定查询从哪个表开始通常与 .join() 等方法一起使用
:param v_join: 创建内连接INNER JOIN操作返回两个表中满足连接条件的交集 :param v_join: 创建内连接INNER JOIN操作返回两个表中满足连接条件的交集
:param v_outerjoin: 用于创建外连接OUTER JOIN操作返回两个表中满足连接条件的并集包括未匹配的行并用 NULL 值填充 :param v_outer_join: 用于创建外连接OUTER JOIN操作返回两个表中满足连接条件的并集包括未匹配的行并用 NULL 值填充
:param v_options: 用于为查询添加附加选项如预加载延迟加载等 :param v_options: 用于为查询添加附加选项如预加载延迟加载等
""" """
if v_select_from: if v_select_from:
@ -406,8 +413,8 @@ class DalBase:
else: else:
v_start_sql = v_start_sql.join(table) v_start_sql = v_start_sql.join(table)
if v_outerjoin: if v_outer_join:
for relation in v_outerjoin: for relation in v_outer_join:
table = relation[0] table = relation[0]
if isinstance(table, str): if isinstance(table, str):
table = getattr(self.model, table) table = getattr(self.model, table)
@ -432,7 +439,7 @@ class DalBase:
sql = sql.where(*conditions) sql = sql.where(*conditions)
return sql return sql
def __dict_filter(self, **kwargs) -> list[BinaryExpression]: def __dict_filter(self, **kwargs) -> List[BinaryExpression]:
""" """
字典过滤 字典过滤
:param model: :param model:
@ -466,6 +473,8 @@ class DalBase:
conditions.append(attr != value[1]) conditions.append(attr != value[1])
elif value[0] == ">": elif value[0] == ">":
conditions.append(attr > value[1]) conditions.append(attr > value[1])
elif value[0] == ">=":
conditions.append(attr >= value[1])
elif value[0] == "<=": elif value[0] == "<=":
conditions.append(attr <= value[1]) conditions.append(attr <= value[1])
else: else:

View File

@ -17,7 +17,7 @@ from pydantic import AfterValidator, PlainSerializer, WithJsonSchema
from .validator import * from .validator import *
def DatetimeStrVali(value: str | datetime.datetime | int | float | dict): def datetime_str_vali(value: str | datetime.datetime | int | float | dict):
""" """
日期时间字符串验证 日期时间字符串验证
如果我传入的是字符串那么直接返回如果我传入的是一个日期类型那么会转为字符串格式后返回 如果我传入的是字符串那么直接返回如果我传入的是一个日期类型那么会转为字符串格式后返回
@ -48,7 +48,7 @@ def DatetimeStrVali(value: str | datetime.datetime | int | float | dict):
# 实现自定义一个日期时间字符串的数据类型 # 实现自定义一个日期时间字符串的数据类型
DatetimeStr = Annotated[ DatetimeStr = Annotated[
str | datetime.datetime | int | float | dict, str | datetime.datetime | int | float | dict,
AfterValidator(DatetimeStrVali), AfterValidator(datetime_str_vali),
PlainSerializer(lambda x: x, return_type=str), PlainSerializer(lambda x: x, return_type=str),
WithJsonSchema({'type': 'string'}, mode='serialization') WithJsonSchema({'type': 'string'}, mode='serialization')
] ]
@ -72,7 +72,7 @@ Email = Annotated[
] ]
def DateStrVali(value: str | datetime.date | int | float): def date_str_vali(value: str | datetime.date | int | float):
""" """
日期字符串验证 日期字符串验证
如果我传入的是字符串那么直接返回如果我传入的是一个日期类型那么会转为字符串格式后返回 如果我传入的是字符串那么直接返回如果我传入的是一个日期类型那么会转为字符串格式后返回
@ -95,13 +95,13 @@ def DateStrVali(value: str | datetime.date | int | float):
# 实现自定义一个日期字符串的数据类型 # 实现自定义一个日期字符串的数据类型
DateStr = Annotated[ DateStr = Annotated[
str | datetime.date | int | float, str | datetime.date | int | float,
AfterValidator(DateStrVali), AfterValidator(date_str_vali),
PlainSerializer(lambda x: x, return_type=str), PlainSerializer(lambda x: x, return_type=str),
WithJsonSchema({'type': 'string'}, mode='serialization') WithJsonSchema({'type': 'string'}, mode='serialization')
] ]
def ObjectIdStrVali(value: str | dict | ObjectId): def object_id_str_vali(value: str | dict | ObjectId):
""" """
官方文档https://docs.pydantic.dev/dev-v2/usage/types/datetime/ 官方文档https://docs.pydantic.dev/dev-v2/usage/types/datetime/
""" """
@ -116,7 +116,7 @@ def ObjectIdStrVali(value: str | dict | ObjectId):
ObjectIdStr = Annotated[ ObjectIdStr = Annotated[
Any, # 这里不能直接使用 any需要使用 typing.Any Any, # 这里不能直接使用 any需要使用 typing.Any
AfterValidator(ObjectIdStrVali), AfterValidator(object_id_str_vali),
PlainSerializer(lambda x: x, return_type=str), PlainSerializer(lambda x: x, return_type=str),
WithJsonSchema({'type': 'string'}, mode='serialization') WithJsonSchema({'type': 'string'}, mode='serialization')
] ]

View File

@ -24,4 +24,4 @@ class SuperEnum(Enum):
@classmethod @classmethod
def values(cls): def values(cls):
"""Returns a list of all the enum values.""" """Returns a list of all the enum values."""
return list(cls._value2member_map_.keys()) return list(cls._value2member_map_.keys())

View File

@ -22,7 +22,20 @@ class MongoManage:
# 倒叙 # 倒叙
ORDER_FIELD = ["desc", "descending"] ORDER_FIELD = ["desc", "descending"]
def __init__(self, db: AsyncIOMotorDatabase, collection: str, schema: Any = None, is_object_id: bool = True): def __init__(
self,
db: AsyncIOMotorDatabase = None,
collection: str = None,
schema: Any = None,
is_object_id: bool = True
):
"""
初始化
:param db:
:param collection: 集合
:param schema:
:param is_object_id: _id 列是否为 ObjectId 格式
"""
self.db = db self.db = db
self.collection = db[collection] self.collection = db[collection]
self.schema = schema self.schema = schema
@ -37,7 +50,6 @@ class MongoManage:
) -> dict | None: ) -> dict | None:
""" """
获取单个数据默认使用 ID 查询否则使用关键词查询 获取单个数据默认使用 ID 查询否则使用关键词查询
:param _id: 数据 ID :param _id: 数据 ID
:param v_return_none: 是否返回空 None否则抛出异常默认抛出异常 :param v_return_none: 是否返回空 None否则抛出异常默认抛出异常
:param v_schema: 指定使用的序列化对象 :param v_schema: 指定使用的序列化对象