from sqlmodel import SQLModel, select
from sqlalchemy.orm import InstrumentedAttribute
from sqlmodel.sql.expression import SelectOfScalar
from sqlmodel.ext.asyncio.session import AsyncSession
from typing import Any, Protocol, Unpack, TypeVar, TypedDict, Generic, NotRequired, Literal, Self, ClassVar, Type, Optional
from pydantic import BaseModel
from collections.abc import Sequence
from arbor_imago import custom_types, models
from arbor_imago.schemas.pagination import Pagination
from arbor_imago.schemas.order_by import OrderBy
# Type variables for the Pydantic payload used to create an instance:
# an invariant one plus contravariant (_contra) and covariant (_co) variants,
# all bound to BaseModel.
TCreateModel = TypeVar('TCreateModel', bound=BaseModel)
TCreateModel_contra = TypeVar('TCreateModel_contra', bound=BaseModel, contravariant=True)
TCreateModel_co = TypeVar('TCreateModel_co', bound=BaseModel, covariant=True)

# The same three variants for the Pydantic payload used to update an instance.
TUpdateModel = TypeVar('TUpdateModel', bound=BaseModel)
TUpdateModel_contra = TypeVar('TUpdateModel_contra', bound=BaseModel, contravariant=True)
TUpdateModel_co = TypeVar('TUpdateModel_co', bound=BaseModel, covariant=True)

# Type of the field names accepted by OrderBy (any str subtype).
TOrderBy_co = TypeVar('TOrderBy_co', bound=str, covariant=True)
class CRUDParamsBase(TypedDict):
    """Keys shared by every parameter dict accepted by ``Service`` methods."""

    # Async database session the operation runs in.
    session: AsyncSession
    # ID of the requesting user. Optional: presumably None for anonymous
    # requests (TODO confirm with the API layer).
    authorized_user_id: Optional[custom_types.User.id]
    # Presumably True when the requester has admin privileges (confirm).
    admin: bool
class WithId(Generic[custom_types.TId], TypedDict):
    """Mixin TypedDict contributing the ``id`` of the targeted instance."""

    id: custom_types.TId
class WithModelInst(Generic[models.TModel_contra], TypedDict):
    """Mixin TypedDict contributing a model instance that was already fetched."""

    model_inst: models.TModel_contra
class CreateParams(Generic[TCreateModel_contra], CRUDParamsBase):
    """Parameters for ``Service.create``: the common CRUD keys plus the creation payload."""

    # Payload that Service.model_inst_from_create_model turns into a model instance.
    create_model: TCreateModel_contra
class ReadParams(Generic[custom_types.TId], CRUDParamsBase, WithId[custom_types.TId]):
    """Parameters for ``Service.read``: the common CRUD keys plus the ``id`` to read."""
class ReadManyBase(Generic[models.TModel, TOrderBy_co], TypedDict):
    """Listing options shared by the ``read_many`` parameter dicts."""

    # Page window; its offset and limit are applied to the query.
    pagination: Pagination
    # Optional sort keys, applied in list order.
    order_bys: NotRequired[list[OrderBy[TOrderBy_co]]]
    # Optional base query to page through; omitted/None selects every row.
    query: NotRequired[SelectOfScalar[models.TModel] | None]
class ReadManyParams(Generic[models.TModel, TOrderBy_co], CRUDParamsBase, ReadManyBase[models.TModel, TOrderBy_co]):
    """Parameters for ``Service.read_many``: the common CRUD keys plus listing options."""
class UpdateParams(Generic[custom_types.TId, TUpdateModel_contra], CRUDParamsBase, WithId[custom_types.TId]):
    """Parameters for ``Service.update``: the common CRUD keys, the target ``id`` and the payload."""

    # Fields to apply; Service._update_model_inst dumps it with exclude_unset=True.
    update_model: TUpdateModel_contra
class DeleteParams(Generic[custom_types.TId], CRUDParamsBase, WithId[custom_types.TId]):
    """Parameters for ``Service.delete``: the common CRUD keys plus the ``id`` to delete."""
# Operations an existing instance can be authorized for; carried in the
# ``operation`` key of CheckAuthorizationExistingParams.
CheckAuthorizationExistingOperation = Literal[
    'read',
    'update',
    'delete',
]
class CheckAuthorizationExistingParams(Generic[models.TModel_contra, custom_types.TId], CRUDParamsBase, WithId[custom_types.TId], WithModelInst[models.TModel_contra]):
    """Arguments for ``Service._check_authorization_existing``.

    These are the common CRUD keys plus the target ``id``, the fetched
    ``model_inst`` and the ``operation`` being requested on it.
    """

    operation: CheckAuthorizationExistingOperation
class CheckAuthorizationNewParams(Generic[TCreateModel_contra], CreateParams[TCreateModel_contra]):
    """Arguments for ``Service._check_authorization_new``; same keys as CreateParams."""
class CheckAuthorizationReadManyParams(Generic[models.TModel, TOrderBy_co], ReadManyParams[models.TModel, TOrderBy_co]):
    """Arguments for ``Service._check_authorization_read_many``; same keys as ReadManyParams."""
class CheckValidationDeleteParams(Generic[custom_types.TId], DeleteParams[custom_types.TId]):
    """Arguments for ``Service._check_validation_delete``; same keys as DeleteParams."""
class CheckValidationPatchParams(Generic[models.TModel, custom_types.TId, TUpdateModel_contra], UpdateParams[custom_types.TId, TUpdateModel_contra], WithModelInst[models.TModel]):
    """Arguments for ``Service._check_validation_patch``: UpdateParams plus the fetched ``model_inst``."""
class CheckValidationPostParams(Generic[TCreateModel_contra], CreateParams[TCreateModel_contra]):
    """Arguments for ``Service._check_validation_post``; same keys as CreateParams."""
class HasModel(Protocol[models.TModel_co]):
    """Structural type for classes that expose the model class they manage."""

    # The model class; Service passes it to ``select`` to build queries.
    _MODEL: Type[models.TModel_co]
class HasModelInstFromCreateModel(Protocol[models.TModel_co, TCreateModel_contra]):
    """Structural type for classes that can build a model instance from a create payload."""

    @classmethod
    def model_inst_from_create_model(cls, create_model: TCreateModel_contra) -> models.TModel_co:
        """Return a model instance built from ``create_model``."""
        ...
class HasModelId(Protocol[models.TModel_contra, custom_types.TId_co]):
    """Structural type for classes that can extract the id of a model instance."""

    @classmethod
    def model_id(cls, inst: models.TModel_contra) -> custom_types.TId_co:
        """Return the identifier of ``inst``."""
        ...
class HasBuildSelectById(Protocol[models.TModel, custom_types.TId_contra]):
    """Structural type for classes that can build a SELECT matching a single id."""

    @classmethod
    def _build_select_by_id(cls, id: custom_types.TId_contra) -> SelectOfScalar[models.TModel]:
        """Return a query selecting the model row(s) identified by ``id``."""
        ...
class SimpleIdModelService(
    Generic[models.TSimpleModel, custom_types.TSimpleId],
    HasModel[models.TSimpleModel],
    HasModelId[models.TSimpleModel, custom_types.TSimpleId],
    HasBuildSelectById[models.TSimpleModel, custom_types.TSimpleId],
):
    """Base providing id helpers for models identified by a single ``id`` attribute.

    Supplies ``model_id`` and ``_build_select_by_id`` by reading and
    filtering on the ``id`` attribute of ``_MODEL``.
    """

    _MODEL: Type[models.TSimpleModel]

    @classmethod
    def model_id(cls, inst: models.TSimpleModel) -> custom_types.TSimpleId:
        """Return ``inst.id``."""
        identifier = inst.id  # type: ignore
        return identifier

    @classmethod
    def _build_select_by_id(cls, id: custom_types.TSimpleId) -> SelectOfScalar[models.TSimpleModel]:
        """Build a SELECT of ``_MODEL`` filtered on ``_MODEL.id == id``."""
        model = cls._MODEL
        id_matches = model.id == id
        return select(model).where(id_matches)
class ServiceError(Exception):
    """Base class for errors raised by the service layer.

    Attributes:
        error_message: the message given at construction. It is also the
            exception's single positional argument, so ``str(exc)`` returns it.
    """

    error_message: str

    def __init__(self, error_message: str):
        super().__init__(error_message)
        self.error_message = error_message
class NotFoundError(ValueError, ServiceError):
    """Raised when no instance of ``model`` exists for the requested id.

    Derives from both ``ValueError`` and ``ServiceError`` so callers can
    catch it as either.
    """

    def __init__(self, model: Type[models.Model], id: custom_types.Id):
        message = NotFoundError.not_found_message(model, id)
        self.error_message = message
        super().__init__(message)

    @staticmethod
    def not_found_message(model: Type[models.Model], id: custom_types.Id) -> str:
        """Build the message ``"<ModelName> with id `<id>` not found"``."""
        return f'{model.__name__} with id `{id!s}` not found'
class AlreadyExistsError(ServiceError):
    """Raised when an instance of ``model`` with the given id already exists."""

    def __init__(self, model: Type[models.Model], id: custom_types.Id):
        # Build the message through a static helper, as NotFoundError does
        # with not_found_message, so callers can reproduce it without raising.
        self.error_message = AlreadyExistsError.already_exists_message(model, id)
        super().__init__(self.error_message)

    @staticmethod
    def already_exists_message(model: Type[models.Model], id: custom_types.Id) -> str:
        """Build the message ``"<ModelName> with id `<id>` already exists"``."""
        return f'{model.__name__} with id `{id!s}` already exists'
class NotAvailableError(ServiceError):
    """``ServiceError`` subclass with no additional behavior.

    Presumably signals that a requested resource is unavailable; confirm at
    the raise sites.
    """
class UnauthorizedError(ServiceError):
    """``ServiceError`` subclass with no additional behavior.

    Presumably raised by the ``_check_authorization_*`` hooks of concrete
    services when access is denied; confirm in subclasses.
    """
class Service(
    Generic[
        models.TModel,
        custom_types.TId,
        TCreateModel,
        TUpdateModel,
        TOrderBy_co
    ],
    HasModel[models.TModel],
    HasModelInstFromCreateModel[models.TModel, TCreateModel],
    HasModelId[models.TModel, custom_types.TId],
    HasBuildSelectById[models.TModel, custom_types.TId],
):
    """Generic async CRUD service for a single model class (``_MODEL``).

    The public ``read`` / ``read_many`` / ``create`` / ``update`` / ``delete``
    classmethods take TypedDict parameter bundles and run the overridable
    ``_check_*`` hooks, which are no-ops here. They then work on the database
    through the supplied ``AsyncSession``. Subclasses supply ``_MODEL``,
    ``model_id`` and ``_build_select_by_id``; SimpleIdModelService provides
    the last two for models with a single ``id`` attribute.
    """

    @classmethod
    async def fetch_one(cls, session: AsyncSession, query: SelectOfScalar[models.TModel]) -> models.TModel | None:
        """Execute ``query`` and return its only result, or None if it matched nothing.

        SQLAlchemy's ``one_or_none`` raises ``MultipleResultsFound`` if the
        query yields more than one row.
        """
        return (await session.exec(query)).one_or_none()

    @classmethod
    async def fetch_many(
        cls,
        session: AsyncSession,
        pagination: Pagination,
        order_bys: list[OrderBy[TOrderBy_co]] | None = None,
        query: SelectOfScalar[models.TModel] | None = None,
    ) -> Sequence[models.TModel]:
        """Return one page of model instances.

        Args:
            session: session to execute the query in.
            pagination: its ``offset`` and ``limit`` select the page.
            order_bys: sort keys applied in list order. None or empty adds no
                ordering.
            query: base query to page through. Defaults to selecting every
                row of ``_MODEL``.
        """
        if query is None:
            query = select(cls._MODEL)
        # The default is None rather than a shared mutable ``[]``.
        query = cls.build_order_by(query, [] if order_bys is None else order_bys)
        query = query.offset(pagination.offset).limit(pagination.limit)
        return (await session.exec(query)).all()

    @classmethod
    async def fetch_by_id(cls, session: AsyncSession, id: custom_types.TId) -> models.TModel | None:
        """Return the instance identified by ``id``, or None if there is none."""
        query = cls._build_select_by_id(id)
        return await cls.fetch_one(session, query)

    @classmethod
    async def fetch_by_id_with_exception(cls, session: AsyncSession, id: custom_types.TId) -> models.TModel:
        """Return the instance identified by ``id``.

        Raises:
            NotFoundError: if no instance has that id.
        """
        inst = await cls.fetch_by_id(session, id)
        if inst is None:
            raise NotFoundError(cls._MODEL, id)
        return inst

    @classmethod
    def build_order_by(cls, query: SelectOfScalar[models.TModel], order_by: list[OrderBy[TOrderBy_co]]) -> SelectOfScalar[models.TModel]:
        """Append an ORDER BY term to ``query`` for each entry of ``order_by``.

        Field names are resolved on the model class (``cls._MODEL``). The
        service class itself declares no column attributes, so looking them up
        on ``cls`` would fail or return an unrelated attribute.

        Raises:
            AttributeError: if an entry names an attribute ``_MODEL`` lacks.
        """
        for order in order_by:
            column: InstrumentedAttribute = getattr(cls._MODEL, order.field)
            query = query.order_by(column.asc() if order.ascending else column.desc())
        return query

    @classmethod
    async def _check_authorization_existing(cls, params: CheckAuthorizationExistingParams[models.TModel, custom_types.TId]) -> None:
        """Hook: decide whether ``params['operation']`` is allowed on ``params['model_inst']``.

        No-op by default, which allows everything. Overrides should raise
        (presumably UnauthorizedError) to deny access.
        """

    @classmethod
    async def _check_authorization_new(cls, params: CheckAuthorizationNewParams[TCreateModel]) -> None:
        """Hook: decide whether the user may create a new instance. No-op by default."""

    @classmethod
    async def _check_authorization_read_many(cls, params: CheckAuthorizationReadManyParams[models.TModel, TOrderBy_co]) -> None:
        """Hook: decide whether the user may list instances. No-op by default."""

    @classmethod
    async def _check_validation_delete(cls, params: CheckValidationDeleteParams[custom_types.TId]) -> None:
        """Hook run before a delete is executed; raise to reject it. No-op by default."""

    @classmethod
    async def _check_validation_patch(cls, params: CheckValidationPatchParams[models.TModel, custom_types.TId, TUpdateModel]) -> None:
        """Hook run before an update is applied; raise to reject it. No-op by default."""

    @classmethod
    async def _check_validation_post(cls, params: CheckValidationPostParams[TCreateModel]) -> None:
        """Hook run before an instance is created; raise to reject it. No-op by default."""

    @classmethod
    async def read(cls, params: ReadParams[custom_types.TId]) -> models.TModel:
        """Used in conjunction with API endpoints, raises exceptions while trying to get an instance of the model by ID.

        Raises:
            NotFoundError: if no instance has ``params['id']``.
        """
        model_inst = await cls.fetch_by_id_with_exception(params['session'], params['id'])
        await cls._check_authorization_existing(
            {**params, 'model_inst': model_inst, 'operation': 'read'})
        return model_inst

    @classmethod
    async def read_many(cls, params: ReadManyParams[models.TModel, TOrderBy_co]) -> Sequence[models.TModel]:
        """Used in conjunction with API endpoints, raises exceptions while trying to get a list of instances of the model.

        Absent or None ``order_bys`` adds no ordering. Absent or None
        ``query`` selects every row of ``_MODEL``.
        """
        await cls._check_authorization_read_many(params)
        return await cls.fetch_many(
            params['session'],
            params['pagination'],
            order_bys=params.get('order_bys'),
            query=params.get('query'),
        )

    @classmethod
    async def create(cls, params: CreateParams[TCreateModel]) -> models.TModel:
        """Used in conjunction with API endpoints, raises exceptions while trying to create a new instance of the model.

        Runs the authorization and validation hooks, builds the instance, then
        adds, commits and refreshes it before returning it.
        """
        await cls._check_authorization_new(params)
        await cls._check_validation_post(params)
        model_inst = cls.model_inst_from_create_model(params['create_model'])
        params['session'].add(model_inst)
        await params['session'].commit()
        await params['session'].refresh(model_inst)
        return model_inst

    @classmethod
    def model_inst_from_create_model(cls, create_model: TCreateModel) -> models.TModel:
        """Build a ``_MODEL`` instance from every field of ``create_model``."""
        return cls._MODEL(**create_model.model_dump())

    @classmethod
    async def update(cls, params: UpdateParams[custom_types.TId, TUpdateModel]) -> models.TModel:
        """Used in conjunction with API endpoints, raises exceptions while trying to update an instance of the model by ID.

        Raises:
            NotFoundError: if no instance has ``params['id']``.
        """
        # when changing this, be sure to update the services/gallery.py file as well
        model_inst = await cls.fetch_by_id_with_exception(params['session'], params['id'])
        # Authorize as an 'update' rather than a 'read', so hooks that grant
        # read access can still deny writes.
        await cls._check_authorization_existing({
            'session': params['session'],
            'model_inst': model_inst,
            'operation': 'update',
            'id': params['id'],
            'admin': params['admin'],
            'authorized_user_id': params['authorized_user_id']
        })
        await cls._check_validation_patch({**params, 'model_inst': model_inst})
        await cls._update_model_inst(model_inst, params['update_model'])
        await params['session'].commit()
        await params['session'].refresh(model_inst)
        return model_inst

    @classmethod
    async def _update_model_inst(cls, inst: models.TModel, update_model: TUpdateModel) -> None:
        """Update an instance of the model from the update model (TUpdateModel).

        Only fields explicitly set on ``update_model`` are applied
        (``exclude_unset=True``).
        """
        inst.sqlmodel_update(update_model.model_dump(exclude_unset=True))

    @classmethod
    async def delete(cls, params: DeleteParams[custom_types.TId]) -> None:
        """Used in conjunction with API endpoints, raises exceptions while trying to delete an instance of the model by ID.

        Raises:
            NotFoundError: if no instance has ``params['id']``.
        """
        model_inst = await cls.fetch_by_id_with_exception(params['session'], params['id'])
        await cls._check_authorization_existing({
            'session': params['session'],
            'operation': 'delete',
            'id': params['id'],
            'model_inst': model_inst,
            'admin': params['admin'],
            'authorized_user_id': params['authorized_user_id']
        })
        await cls._check_validation_delete(params)
        await params['session'].delete(model_inst)
        await params['session'].commit()
# NOTE(review): the triple-quoted string below is a discarded expression that
# holds legacy helper code; none of it runs. It references names not imported
# in this module (uuid, TId, typing, collections, and_) and attributes not
# defined here (self.fields, item._id). Consider deleting it and relying on
# version control history instead.
'''
def generate(self) -> TId:
if len(self.fields) == 1:
return str(uuid.uuid4())
return tuple(str(uuid.uuid4()) for _ in self.fields)
@classmethod
def export_plural_to_dict(cls, items: collections.abc.Iterable[typing.Self]) -> dict[TId, typing.Self]:
return {item._id: item for item in items}
# @classmethod
# def _build_conditions(cls, filters: dict[str, typing.Any]):
# conditions = []
# for key in filters:
# value = filters[key]
# field: InstrumentedAttribute = getattr(cls, key)
# if isinstance(value, list):
# conditions.append(field.in_(value))
# else:
# conditions.append(field == value)
# return and_(*conditions)
'''