1
0
2025-06-01 21:39:53 +09:00

65 lines
2.0 KiB
Python

"""Base repository class"""
from typing import Generic, Type, TypeVar, Optional, List
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, delete
from sqlalchemy.orm import selectinload
from app.db.base import Base
# Type variable for the model class a repository is parameterized with;
# it is bound to ``Base`` (imported from app.db.base).
ModelType = TypeVar("ModelType", bound=Base)
class BaseRepository(Generic[ModelType]):
    """Base repository with common CRUD operations.

    Wraps a model class and an ``AsyncSession``. Lookups by key go through
    ``self.model.id``, so the model must expose an ``id`` attribute.

    No method here commits; transaction control stays with whoever owns the
    session.
    """

    def __init__(self, model: Type[ModelType], session: AsyncSession):
        """Store the model class and the session used for every query.

        Args:
            model: Model class this repository operates on.
            session: Async session all statements are executed with.
        """
        self.model = model
        self.session = session

    async def create(self, **kwargs) -> ModelType:
        """Create a new record.

        Args:
            **kwargs: Passed unchanged to the model constructor.

        Returns:
            The new instance, added to the session and flushed (not
            committed).
        """
        instance = self.model(**kwargs)
        self.session.add(instance)
        # Flush so the pending INSERT is sent before the instance is returned.
        await self.session.flush()
        return instance

    async def get(self, id: int) -> Optional[ModelType]:
        """Get a record by ID.

        Args:
            id: Value compared against ``self.model.id``.

        Returns:
            The matching instance, or ``None`` when no row matches.
        """
        result = await self.session.execute(
            select(self.model).where(self.model.id == id)
        )
        return result.scalar_one_or_none()

    async def get_multi(
        self,
        skip: int = 0,
        limit: int = 100,
        **filters
    ) -> List[ModelType]:
        """Get multiple records with pagination.

        Args:
            skip: Number of rows to skip (OFFSET).
            limit: Maximum number of rows to return (LIMIT).
            **filters: Equality filters as ``attribute_name=value``. Names
                that are not attributes of the model are silently ignored.

        Returns:
            A list of matching instances ordered by ``id``.
        """
        query = select(self.model)

        # Apply equality filters for names the model actually has.
        for key, value in filters.items():
            if hasattr(self.model, key):
                query = query.where(getattr(self.model, key) == value)

        # OFFSET/LIMIT without ORDER BY leaves row order unspecified, so
        # consecutive pages could overlap or skip rows. Ordering by id keeps
        # pagination stable.
        query = query.order_by(self.model.id).offset(skip).limit(limit)

        result = await self.session.execute(query)
        # Materialize as a real list to match the declared return type.
        return list(result.scalars().all())

    async def update(self, id: int, **kwargs) -> Optional[ModelType]:
        """Update a record.

        Emits an UPDATE statement for the row whose ``id`` matches, then
        reads the row back through :meth:`get`.

        Args:
            id: Value compared against ``self.model.id``.
            **kwargs: Column values to set.

        Returns:
            The instance as read back after the UPDATE, or ``None`` when no
            row with that id exists.
        """
        await self.session.execute(
            update(self.model)
            .where(self.model.id == id)
            .values(**kwargs)
        )
        return await self.get(id)

    async def delete(self, id: int) -> bool:
        """Delete a record.

        Args:
            id: Value compared against ``self.model.id``.

        Returns:
            ``True`` when the DELETE statement reported at least one affected
            row, ``False`` otherwise.
        """
        result = await self.session.execute(
            delete(self.model).where(self.model.id == id)
        )
        return result.rowcount > 0