版本更新:

1. 优化(kinit-api):core/crud.py 增删改查操作基类优化
2. 更新(kinit-api):更新 core/dependencies.py 查询参数基类的 dict,to_count方法,新增排除参数
This commit is contained in:
ktianc 2023-03-18 14:32:42 +08:00
parent 94954f362c
commit e8979577bd
5 changed files with 54 additions and 66 deletions

View File

@ -11,7 +11,7 @@ from fastapi.security import OAuth2PasswordBearer
""" """
系统版本 系统版本
""" """
VERSION = "1.6.3" VERSION = "1.6.4"
"""安全警告: 不要在生产中打开调试运行!""" """安全警告: 不要在生产中打开调试运行!"""
DEBUG = True DEBUG = True

View File

@ -66,13 +66,7 @@ class UserDal(DalBase):
for role in roles: for role in roles:
obj.roles.append(role) obj.roles.append(role)
await self.flush(obj) await self.flush(obj)
if v_options: return await self.out_dict(obj, v_options, v_return_obj, v_schema)
obj = await self.get_data(obj.id, v_options=v_options)
if v_return_obj:
return obj
if v_schema:
return v_schema.from_orm(obj).dict()
return self.out_dict(obj)
async def put_data( async def put_data(
self, self,
@ -97,13 +91,8 @@ class UserDal(DalBase):
obj.roles.append(role) obj.roles.append(role)
continue continue
setattr(obj, key, value) setattr(obj, key, value)
await self.db.flush() await self.flush(obj)
await self.db.refresh(obj) return await self.out_dict(obj, None, v_return_obj, v_schema)
if v_return_obj:
return obj
if v_schema:
return v_schema.from_orm(obj).dict()
return self.out_dict(obj)
async def reset_current_password(self, user: models.VadminUser, data: schemas.ResetPwd): async def reset_current_password(self, user: models.VadminUser, data: schemas.ResetPwd):
""" """
@ -133,7 +122,7 @@ class UserDal(DalBase):
user.nickname = data.nickname user.nickname = data.nickname
user.gender = data.gender user.gender = data.gender
await self.flush(user) await self.flush(user)
return self.out_dict(user) return await self.out_dict(user)
async def export_query_list(self, header: list, params: UserParams): async def export_query_list(self, header: list, params: UserParams):
""" """
@ -309,13 +298,7 @@ class RoleDal(DalBase):
for menu in menus: for menu in menus:
obj.menus.append(menu) obj.menus.append(menu)
await self.flush(obj) await self.flush(obj)
if v_options: return await self.out_dict(obj, v_options, v_return_obj, v_schema)
obj = await self.get_data(obj.id, v_options=v_options)
if v_return_obj:
return obj
if v_schema:
return v_schema.from_orm(obj).dict()
return self.out_dict(await self.get_data(obj.id))
async def put_data( async def put_data(
self, self,
@ -338,13 +321,8 @@ class RoleDal(DalBase):
obj.menus.append(menu) obj.menus.append(menu)
continue continue
setattr(obj, key, value) setattr(obj, key, value)
await self.db.flush() await self.flush(obj)
await self.db.refresh(obj) return await self.out_dict(obj, None, v_return_obj, v_schema)
if v_return_obj:
return obj
if v_schema:
return v_schema.from_orm(obj).dict()
return self.out_dict(obj)
async def get_role_menu_tree(self, role_id: int): async def get_role_menu_tree(self, role_id: int):
role = await self.get_data(role_id, v_options=[joinedload(self.model.menus)]) role = await self.get_data(role_id, v_options=[joinedload(self.model.menus)])

View File

@ -34,8 +34,12 @@ class DictTypeDal(DalBase):
""" """
data = {} data = {}
options = [joinedload(self.model.details)] options = [joinedload(self.model.details)]
objs = await DictTypeDal(self.db).\ objs = await DictTypeDal(self.db).get_datas(
get_datas(limit=0, v_return_objs=True, v_options=options, dict_type=("in", dict_types)) limit=0,
v_return_objs=True,
v_options=options,
dict_type=("in", dict_types)
)
for obj in objs: for obj in objs:
if not obj: if not obj:
data[obj.dict_type] = [] data[obj.dict_type] = []
@ -163,6 +167,3 @@ class SettingsTabDal(DalBase):
tabs[item.config_key] = item.config_value tabs[item.config_key] = item.config_value
result[tab.tab_name] = tabs result[tab.tab_name] = tabs
return result return result

View File

@ -7,6 +7,7 @@
# @desc : 数据库 增删改查操作 # @desc : 数据库 增删改查操作
# sqlalchemy 查询操作https://segmentfault.com/a/1190000016767008 # sqlalchemy 查询操作https://segmentfault.com/a/1190000016767008
# sqlalchemy 查询操作(官方文档): https://www.osgeo.cn/sqlalchemy/orm/queryguide.html
# sqlalchemy 增删改操作https://www.osgeo.cn/sqlalchemy/tutorial/orm_data_manipulation.html#updating-orm-objects # sqlalchemy 增删改操作https://www.osgeo.cn/sqlalchemy/tutorial/orm_data_manipulation.html#updating-orm-objects
# SQLAlchemy lazy load和eager load: https://www.jianshu.com/p/dfad7c08c57a # SQLAlchemy lazy load和eager load: https://www.jianshu.com/p/dfad7c08c57a
# Mysql中内连接,左连接和右连接的区别总结:https://www.cnblogs.com/restartyang/articles/9080993.html # Mysql中内连接,左连接和右连接的区别总结:https://www.cnblogs.com/restartyang/articles/9080993.html
@ -28,11 +29,13 @@ from starlette import status
from core.logger import logger from core.logger import logger
from sqlalchemy.sql.selectable import Select from sqlalchemy.sql.selectable import Select
from typing import Any from typing import Any
from sqlalchemy.engine.result import ScalarResult
class DalBase: class DalBase:
# 倒叙
ORDER_FIELD = ["desc", "descending"]
def __init__(self, db: AsyncSession, model: Any, schema: Any, key_models: dict = None): def __init__(self, db: AsyncSession, model: Any, schema: Any, key_models: dict = None):
self.db = db self.db = db
self.model = model self.model = model
@ -64,7 +67,7 @@ class DalBase:
if data_id: if data_id:
sql = sql.where(self.model.id == data_id) sql = sql.where(self.model.id == data_id)
sql = self.add_filter_condition(sql, v_join_query, v_options, **kwargs) sql = self.add_filter_condition(sql, v_join_query, v_options, **kwargs)
if v_order and (v_order == "desc" or v_order == "descending"): if v_order and (v_order in self.ORDER_FIELD):
sql = sql.order_by(self.model.create_datetime.desc()) sql = sql.order_by(self.model.create_datetime.desc())
queryset = await self.db.execute(sql) queryset = await self.db.execute(sql)
data = queryset.scalars().unique().first() data = queryset.scalars().unique().first()
@ -105,20 +108,18 @@ class DalBase:
if not isinstance(v_start_sql, Select): if not isinstance(v_start_sql, Select):
v_start_sql = select(self.model).where(self.model.is_delete == False) v_start_sql = select(self.model).where(self.model.is_delete == False)
sql = self.add_filter_condition(v_start_sql, v_join_query, v_options, **kwargs) sql = self.add_filter_condition(v_start_sql, v_join_query, v_options, **kwargs)
if v_order_field and (v_order == "desc" or v_order == "descending"): if v_order_field and (v_order in self.ORDER_FIELD):
sql = sql.order_by(getattr(self.model, v_order_field).desc(), self.model.id.desc()) sql = sql.order_by(getattr(self.model, v_order_field).desc(), self.model.id.desc())
elif v_order_field: elif v_order_field:
sql = sql.order_by(getattr(self.model, v_order_field), self.model.id) sql = sql.order_by(getattr(self.model, v_order_field), self.model.id)
elif v_order == "desc" or v_order == "descending": elif v_order in self.ORDER_FIELD:
sql = sql.order_by(self.model.id.desc()) sql = sql.order_by(self.model.id.desc())
if limit != 0: if limit != 0:
sql = sql.offset((page - 1) * limit).limit(limit) sql = sql.offset((page - 1) * limit).limit(limit)
queryset = await self.db.execute(sql) queryset = await self.db.execute(sql)
if v_return_objs: if v_return_objs:
return queryset.scalars().unique().all() return queryset.scalars().unique().all()
if v_schema: return [await self.out_dict(i, v_schema=v_schema) for i in queryset.scalars().unique().all()]
return [v_schema.from_orm(i).dict() for i in queryset.scalars().unique().all()]
return [self.out_dict(i) for i in queryset.scalars().unique().all()]
async def get_count(self, v_join_query: dict = None, v_options: list = None, **kwargs): async def get_count(self, v_join_query: dict = None, v_options: list = None, **kwargs):
""" """
@ -145,13 +146,7 @@ class DalBase:
else: else:
obj = self.model(**data.dict()) obj = self.model(**data.dict())
await self.flush(obj) await self.flush(obj)
if v_options: return await self.out_dict(obj, v_options, v_return_obj, v_schema)
obj = await self.get_data(obj.id, v_options=v_options)
if v_return_obj:
return obj
if v_schema:
return v_schema.from_orm(obj).dict()
return self.out_dict(obj)
async def put_data( async def put_data(
self, self,
@ -174,11 +169,7 @@ class DalBase:
for key, value in obj_dict.items(): for key, value in obj_dict.items():
setattr(obj, key, value) setattr(obj, key, value)
await self.flush(obj) await self.flush(obj)
if v_return_obj: return await self.out_dict(obj, None, v_return_obj, v_schema)
return obj
if v_schema:
return v_schema.from_orm(obj).dict()
return self.out_dict(obj)
async def delete_datas(self, ids: List[int], v_soft: bool = False, **kwargs): async def delete_datas(self, ids: List[int], v_soft: bool = False, **kwargs):
""" """
@ -189,9 +180,11 @@ class DalBase:
""" """
if v_soft: if v_soft:
await self.db.execute( await self.db.execute(
update(self.model) update(self.model).where(self.model.id.in_(ids)).values(
.where(self.model.id.in_(ids)) delete_datetime=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
.values(delete_datetime=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), is_delete=True, **kwargs) is_delete=True,
**kwargs
)
) )
else: else:
await self.db.execute(delete(self.model).where(self.model.id.in_(ids))) await self.db.execute(delete(self.model).where(self.model.id.in_(ids)))
@ -209,7 +202,7 @@ class DalBase:
foreign_key = self.key_models.get(key) foreign_key = self.key_models.get(key)
if foreign_key and foreign_key.get("model"): if foreign_key and foreign_key.get("model"):
# 当外键模型在查询模型中存在多个外键时则需要添加onclause属性 # 当外键模型在查询模型中存在多个外键时则需要添加onclause属性
sql = sql.join(foreign_key.get("model"), onclause=foreign_key.get("onclause")) sql = sql.join(foreign_key.get("model"), onclause=foreign_key.get("onclause", None))
for v_key, v_value in value.items(): for v_key, v_value in value.items():
if v_value is not None and v_value != "": if v_value is not None and v_value != "":
v_attr = getattr(foreign_key.get("model"), v_key, None) v_attr = getattr(foreign_key.get("model"), v_key, None)
@ -250,7 +243,7 @@ class DalBase:
sql = sql.where(or_(i for i in value[1])) sql = sql.where(or_(i for i in value[1]))
elif value[0] == "in": elif value[0] == "in":
sql = sql.where(attr.in_(value[1])) sql = sql.where(attr.in_(value[1]))
elif value[0] == "between": elif value[0] == "between" and len(value[1]) == 2:
sql = sql.where(attr.between(value[1][0], value[1][1])) sql = sql.where(attr.between(value[1][0], value[1][1]))
elif value[0] == "month": elif value[0] == "month":
sql = sql.where(func.date_format(attr, "%Y-%m") == value[1]) sql = sql.where(func.date_format(attr, "%Y-%m") == value[1])
@ -272,10 +265,19 @@ class DalBase:
if obj: if obj:
await self.db.refresh(obj) await self.db.refresh(obj)
def out_dict(self, data: Any): async def out_dict(self, obj: Any, v_options: list = None, v_return_obj: bool = False, v_schema: Any = None):
""" """
序列化 序列化
:param data: :param obj:
:param v_options: 指示应使用select在预加载中加载给定的属性
:param v_return_obj: 是否返回对象
:param v_schema: 指定使用的序列化对象
:return: :return:
""" """
return self.schema.from_orm(data).dict() if v_options:
obj = await self.get_data(obj.id, v_options=v_options)
if v_return_obj:
return obj
if v_schema:
return v_schema.from_orm(obj).dict()
return self.schema.from_orm(obj).dict()

View File

@ -24,11 +24,18 @@ class QueryParams:
self.v_order = params.v_order self.v_order = params.v_order
self.v_order_field = params.v_order_field self.v_order_field = params.v_order_field
def dict(self) -> dict: def dict(self, exclude: List[str] = None) -> dict:
return self.__dict__ result = copy.deepcopy(self.__dict__)
if exclude:
for item in exclude:
try:
del result[item]
except KeyError:
pass
return result
def to_count(self) -> dict: def to_count(self, exclude: List[str] = None) -> dict:
params = copy.deepcopy(self.__dict__) params = self.dict(exclude=exclude)
del params["page"] del params["page"]
del params["limit"] del params["limit"]
del params["v_order"] del params["v_order"]