优化核心类

This commit is contained in:
ktianc 2023-12-16 21:16:49 +08:00
parent 26ffb4c167
commit 518f9d4a47
4 changed files with 88 additions and 67 deletions

View File

@ -14,21 +14,21 @@
import datetime
from fastapi import HTTPException
from fastapi.encoders import jsonable_encoder
from sqlalchemy import func, delete, update, BinaryExpression, ScalarResult, select
from sqlalchemy import func, delete, update, BinaryExpression, ScalarResult, select, false, insert
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import InstrumentedAttribute
from sqlalchemy.orm.strategy_options import _AbstractLoad
from starlette import status
from core.exception import CustomException
from sqlalchemy.sql.selectable import Select as SelectType
from typing import Any
from typing import Any, List, Union
class DalBase:
# 倒叙
ORDER_FIELD = ["desc", "descending"]
def __init__(self, db: AsyncSession, model: Any, schema: Any):
def __init__(self, db: AsyncSession = None, model: Any = None, schema: Any = None):
self.db = db
self.model = model
self.schema = schema
@ -37,11 +37,11 @@ class DalBase:
self,
data_id: int = None,
v_start_sql: SelectType = None,
v_select_from: list[Any] = None,
v_join: list[list[str | InstrumentedAttribute, BinaryExpression | None]] = None,
v_outerjoin: list[list[str | InstrumentedAttribute, BinaryExpression | None]] = None,
v_options: list[_AbstractLoad] = None,
v_where: list[BinaryExpression] = None,
v_select_from: List[Any] = None,
v_join: List[Any] = None,
v_outer_join: List[Any] = None,
v_options: List[_AbstractLoad] = None,
v_where: List[BinaryExpression] = None,
v_order: str = None,
v_order_field: str = None,
v_return_none: bool = False,
@ -55,7 +55,7 @@ class DalBase:
:param v_start_sql: 初始 sql
:param v_select_from: 用于指定查询从哪个表开始通常与 .join() 等方法一起使用
:param v_join: 创建内连接INNER JOIN操作返回两个表中满足连接条件的交集
:param v_outerjoin: 用于创建外连接OUTER JOIN操作返回两个表中满足连接条件的并集包括未匹配的行并用 NULL 值填充
:param v_outer_join: 用于创建外连接OUTER JOIN操作返回两个表中满足连接条件的并集包括未匹配的行并用 NULL 值填充
:param v_options: 用于为查询添加附加选项如预加载延迟加载等
:param v_where: 当前表查询条件原始表达式
:param v_order: 排序默认正序 desc 是倒叙
@ -66,16 +66,16 @@ class DalBase:
:return: 默认返回 ORM 对象如果存在 v_schema 则会返回 v_schema 结果
"""
if not isinstance(v_start_sql, SelectType):
v_start_sql = select(self.model).where(self.model.is_delete == False)
v_start_sql = select(self.model).where(self.model.is_delete == false())
if data_id:
if data_id is not None:
v_start_sql = v_start_sql.where(self.model.id == data_id)
queryset: ScalarResult = await self.filter_core(
v_start_sql=v_start_sql,
v_select_from=v_select_from,
v_join=v_join,
v_outerjoin=v_outerjoin,
v_outer_join=v_outer_join,
v_options=v_options,
v_where=v_where,
v_order=v_order,
@ -105,19 +105,20 @@ class DalBase:
page: int = 1,
limit: int = 10,
v_start_sql: SelectType = None,
v_select_from: list[Any] = None,
v_join: list[list[str | InstrumentedAttribute, BinaryExpression | None]] = None,
v_outerjoin: list[list[str | InstrumentedAttribute, BinaryExpression | None]] = None,
v_options: list[_AbstractLoad] = None,
v_where: list[BinaryExpression] = None,
v_select_from: List[Any] = None,
v_join: List[Any] = None,
v_outer_join: List[Any] = None,
v_options: List[_AbstractLoad] = None,
v_where: List[BinaryExpression] = None,
v_order: str = None,
v_order_field: str = None,
v_return_count: bool = False,
v_return_scalars: bool = False,
v_return_objs: bool = False,
v_schema: Any = None,
v_distinct: bool = False,
**kwargs
) -> list[Any] | ScalarResult | tuple:
) -> Union[List[Any], ScalarResult, tuple]:
"""
获取数据列表
@ -126,7 +127,7 @@ class DalBase:
:param v_start_sql: 初始 sql
:param v_select_from: 用于指定查询从哪个表开始通常与 .join() 等方法一起使用
:param v_join: 创建内连接INNER JOIN操作返回两个表中满足连接条件的交集
:param v_outerjoin: 用于创建外连接OUTER JOIN操作返回两个表中满足连接条件的并集包括未匹配的行并用 NULL 值填充
:param v_outer_join: 用于创建外连接OUTER JOIN操作返回两个表中满足连接条件的并集包括未匹配的行并用 NULL 值填充
:param v_options: 用于为查询添加附加选项如预加载延迟加载等
:param v_where: 当前表查询条件原始表达式
:param v_order: 排序默认正序 desc 是倒叙
@ -135,6 +136,7 @@ class DalBase:
:param v_return_scalars: 返回scalars后的结果
:param v_return_objs: 是否返回对象
:param v_schema: 指定使用的序列化对象
:param v_distinct: 是否结果去重
:param kwargs: 查询参数使用的是自定义表达式
:return: 返回值优先级v_return_scalars > v_return_objs > v_schema
"""
@ -142,7 +144,7 @@ class DalBase:
v_start_sql=v_start_sql,
v_select_from=v_select_from,
v_join=v_join,
v_outerjoin=v_outerjoin,
v_outer_join=v_outer_join,
v_options=v_options,
v_where=v_where,
v_order=v_order,
@ -151,6 +153,9 @@ class DalBase:
**kwargs
)
if v_distinct:
sql = sql.distinct()
count = 0
if v_return_count:
count_sql = select(func.count()).select_from(sql.alias())
@ -184,10 +189,10 @@ class DalBase:
async def get_count(
self,
v_select_from: list[Any] = None,
v_join: list[list[str | InstrumentedAttribute, BinaryExpression | None]] = None,
v_outerjoin: list[list[str | InstrumentedAttribute, BinaryExpression | None]] = None,
v_where: list[BinaryExpression] = None,
v_select_from: List[Any] = None,
v_join: List[Any] = None,
v_outer_join: List[Any] = None,
v_where: List[BinaryExpression] = None,
**kwargs
) -> int:
"""
@ -195,7 +200,7 @@ class DalBase:
:param v_select_from: 用于指定查询从哪个表开始通常与 .join() 等方法一起使用
:param v_join: 创建内连接INNER JOIN操作返回两个表中满足连接条件的交集
:param v_outerjoin: 用于创建外连接OUTER JOIN操作返回两个表中满足连接条件的并集包括未匹配的行并用 NULL 值填充
:param v_outer_join: 用于创建外连接OUTER JOIN操作返回两个表中满足连接条件的并集包括未匹配的行并用 NULL 值填充
:param v_where: 当前表查询条件原始表达式
:param kwargs: 查询参数
"""
@ -204,7 +209,7 @@ class DalBase:
v_start_sql=v_start_sql,
v_select_from=v_select_from,
v_join=v_join,
v_outerjoin=v_outerjoin,
v_outer_join=v_outer_join,
v_where=v_where,
v_return_sql=True,
**kwargs
@ -215,13 +220,12 @@ class DalBase:
async def create_data(
self,
data,
v_options: list[_AbstractLoad] = None,
v_options: List[_AbstractLoad] = None,
v_return_obj: bool = False,
v_schema: Any = None
) -> Any:
"""
创建单个数据
:param data: 创建数据
:param v_options: 指示应使用select在预加载中加载给定的属性
:param v_schema: 指定使用的序列化对象
@ -234,22 +238,22 @@ class DalBase:
await self.flush(obj)
return await self.out_dict(obj, v_options, v_return_obj, v_schema)
# async def create_datas(self, datas: list[dict]) -> None:
# """
# 批量创建数据,暂不启用
# SQLAlchemy 2.0 批量插入不支持 MySQL 返回对象
# https://docs.sqlalchemy.org/en/20/orm/queryguide/dml.html#getting-new-objects-with-returning
#
# :param datas: 字典数据列表
# """
# await self.db.execute(insert(self.model), datas)
# await self.db.flush()
async def create_datas(self, datas: List[dict]) -> None:
"""
批量创建数据
SQLAlchemy 2.0 批量插入不支持 MySQL 返回值
https://docs.sqlalchemy.org/en/20/orm/queryguide/dml.html#getting-new-objects-with-returning
:param datas: 字典数据列表
"""
await self.db.execute(insert(self.model), datas)
await self.db.flush()
async def put_data(
self,
data_id: int,
data: Any,
v_options: list[_AbstractLoad] = None,
v_options: List[_AbstractLoad] = None,
v_return_obj: bool = False,
v_schema: Any = None
) -> Any:
@ -268,7 +272,7 @@ class DalBase:
await self.flush(obj)
return await self.out_dict(obj, None, v_return_obj, v_schema)
async def delete_datas(self, ids: list[int], v_soft: bool = False, **kwargs) -> None:
async def delete_datas(self, ids: List[int], v_soft: bool = False, **kwargs) -> None:
"""
删除多条数据
:param ids: 数据集
@ -290,18 +294,21 @@ class DalBase:
async def flush(self, obj: Any = None) -> Any:
"""
刷新到数据库
:param obj:
:return:
"""
if obj:
self.db.add(obj)
await self.db.flush()
if obj:
# 使用 get_data 或者 get_datas 获取到实例后如何更新了实例,并需要序列化实例,那么需要执行 refresh 刷新才能正常序列化
await self.db.refresh(obj)
return obj
async def out_dict(
self,
obj: Any,
v_options: list[_AbstractLoad] = None,
v_options: List[_AbstractLoad] = None,
v_return_obj: bool = False,
v_schema: Any = None
) -> Any:
@ -324,23 +331,23 @@ class DalBase:
async def filter_core(
self,
v_start_sql: SelectType = None,
v_select_from: list[Any] = None,
v_join: list[list[str | InstrumentedAttribute, BinaryExpression | None]] = None,
v_outerjoin: list[list[str | InstrumentedAttribute, BinaryExpression | None]] = None,
v_options: list[_AbstractLoad] = None,
v_where: list[BinaryExpression] = None,
v_select_from: List[Any] = None,
v_join: List[Any] = None,
v_outer_join: List[Any] = None,
v_options: List[_AbstractLoad] = None,
v_where: List[BinaryExpression] = None,
v_order: str = None,
v_order_field: str = None,
v_return_sql: bool = False,
**kwargs
) -> ScalarResult | SelectType:
) -> Union[ScalarResult, SelectType]:
"""
数据过滤核心功能
:param v_start_sql: 初始 sql
:param v_select_from: 用于指定查询从哪个表开始通常与 .join() 等方法一起使用
:param v_join: 创建内连接INNER JOIN操作返回两个表中满足连接条件的交集
:param v_outerjoin: 用于创建外连接OUTER JOIN操作返回两个表中满足连接条件的并集包括未匹配的行并用 NULL 值填充
:param v_outer_join: 用于创建外连接OUTER JOIN操作返回两个表中满足连接条件的并集包括未匹配的行并用 NULL 值填充
:param v_options: 用于为查询添加附加选项如预加载延迟加载等
:param v_where: 当前表查询条件原始表达式
:param v_order: 排序默认正序 desc 是倒叙
@ -349,13 +356,13 @@ class DalBase:
:return: 返回过滤后的总数居 sql
"""
if not isinstance(v_start_sql, SelectType):
v_start_sql = select(self.model).where(self.model.is_delete == False)
v_start_sql = select(self.model).where(self.model.is_delete == false())
sql = self.add_relation(
v_start_sql=v_start_sql,
v_select_from=v_select_from,
v_join=v_join,
v_outerjoin=v_outerjoin,
v_outer_join=v_outer_join,
v_options=v_options
)
@ -381,16 +388,16 @@ class DalBase:
def add_relation(
self,
v_start_sql: SelectType,
v_select_from: list[Any] = None,
v_join: list[list[str | InstrumentedAttribute, BinaryExpression | None]] = None,
v_outerjoin: list[list[str | InstrumentedAttribute, BinaryExpression | None]] = None,
v_options: list[_AbstractLoad] = None,
v_select_from: List[Any] = None,
v_join: List[Any] = None,
v_outer_join: List[Any] = None,
v_options: List[_AbstractLoad] = None,
) -> SelectType:
"""
:param v_start_sql: 初始 sql
:param v_select_from: 用于指定查询从哪个表开始通常与 .join() 等方法一起使用
:param v_join: 创建内连接INNER JOIN操作返回两个表中满足连接条件的交集
:param v_outerjoin: 用于创建外连接OUTER JOIN操作返回两个表中满足连接条件的并集包括未匹配的行并用 NULL 值填充
:param v_outer_join: 用于创建外连接OUTER JOIN操作返回两个表中满足连接条件的并集包括未匹配的行并用 NULL 值填充
:param v_options: 用于为查询添加附加选项如预加载延迟加载等
"""
if v_select_from:
@ -406,8 +413,8 @@ class DalBase:
else:
v_start_sql = v_start_sql.join(table)
if v_outerjoin:
for relation in v_outerjoin:
if v_outer_join:
for relation in v_outer_join:
table = relation[0]
if isinstance(table, str):
table = getattr(self.model, table)
@ -432,7 +439,7 @@ class DalBase:
sql = sql.where(*conditions)
return sql
def __dict_filter(self, **kwargs) -> list[BinaryExpression]:
def __dict_filter(self, **kwargs) -> List[BinaryExpression]:
"""
字典过滤
:param model:
@ -466,6 +473,8 @@ class DalBase:
conditions.append(attr != value[1])
elif value[0] == ">":
conditions.append(attr > value[1])
elif value[0] == ">=":
conditions.append(attr >= value[1])
elif value[0] == "<=":
conditions.append(attr <= value[1])
else:

View File

@ -17,7 +17,7 @@ from pydantic import AfterValidator, PlainSerializer, WithJsonSchema
from .validator import *
def DatetimeStrVali(value: str | datetime.datetime | int | float | dict):
def datetime_str_vali(value: str | datetime.datetime | int | float | dict):
"""
日期时间字符串验证
如果我传入的是字符串那么直接返回如果我传入的是一个日期类型那么会转为字符串格式后返回
@ -48,7 +48,7 @@ def DatetimeStrVali(value: str | datetime.datetime | int | float | dict):
# 实现自定义一个日期时间字符串的数据类型
DatetimeStr = Annotated[
str | datetime.datetime | int | float | dict,
AfterValidator(DatetimeStrVali),
AfterValidator(datetime_str_vali),
PlainSerializer(lambda x: x, return_type=str),
WithJsonSchema({'type': 'string'}, mode='serialization')
]
@ -72,7 +72,7 @@ Email = Annotated[
]
def DateStrVali(value: str | datetime.date | int | float):
def date_str_vali(value: str | datetime.date | int | float):
"""
日期字符串验证
如果我传入的是字符串那么直接返回如果我传入的是一个日期类型那么会转为字符串格式后返回
@ -95,13 +95,13 @@ def DateStrVali(value: str | datetime.date | int | float):
# 实现自定义一个日期字符串的数据类型
DateStr = Annotated[
str | datetime.date | int | float,
AfterValidator(DateStrVali),
AfterValidator(date_str_vali),
PlainSerializer(lambda x: x, return_type=str),
WithJsonSchema({'type': 'string'}, mode='serialization')
]
def ObjectIdStrVali(value: str | dict | ObjectId):
def object_id_str_vali(value: str | dict | ObjectId):
"""
官方文档https://docs.pydantic.dev/dev-v2/usage/types/datetime/
"""
@ -116,7 +116,7 @@ def ObjectIdStrVali(value: str | dict | ObjectId):
ObjectIdStr = Annotated[
Any, # 这里不能直接使用 any需要使用 typing.Any
AfterValidator(ObjectIdStrVali),
AfterValidator(object_id_str_vali),
PlainSerializer(lambda x: x, return_type=str),
WithJsonSchema({'type': 'string'}, mode='serialization')
]

View File

@ -24,4 +24,4 @@ class SuperEnum(Enum):
@classmethod
def values(cls):
"""Returns a list of all the enum values."""
return list(cls._value2member_map_.keys())
return list(cls._value2member_map_.keys())

View File

@ -22,7 +22,20 @@ class MongoManage:
# 倒叙
ORDER_FIELD = ["desc", "descending"]
def __init__(self, db: AsyncIOMotorDatabase, collection: str, schema: Any = None, is_object_id: bool = True):
def __init__(
self,
db: AsyncIOMotorDatabase = None,
collection: str = None,
schema: Any = None,
is_object_id: bool = True
):
"""
初始化
:param db:
:param collection: 集合
:param schema:
:param is_object_id: _id 列是否为 ObjectId 格式
"""
self.db = db
self.collection = db[collection]
self.schema = schema
@ -37,7 +50,6 @@ class MongoManage:
) -> dict | None:
"""
获取单个数据默认使用 ID 查询否则使用关键词查询
:param _id: 数据 ID
:param v_return_none: 是否返回空 None否则抛出异常默认抛出异常
:param v_schema: 指定使用的序列化对象