"""Base repository class"""
from typing import Generic, Type, TypeVar, Optional, List

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, delete
from sqlalchemy.orm import selectinload

from app.db.base import Base

ModelType = TypeVar("ModelType", bound=Base)


class BaseRepository(Generic[ModelType]):
    """Base repository with common async CRUD operations.

    Wraps one mapped model class and an ``AsyncSession``. The methods add,
    flush, or execute statements on the session but never commit, so
    transaction boundaries are controlled by the caller.

    Every lookup filters on ``self.model.id``, so the model is expected to
    expose an ``id`` column.
    """

    def __init__(self, model: Type[ModelType], session: AsyncSession):
        """Bind the repository to a model class and an async session.

        Args:
            model: Mapped model class this repository operates on.
            session: Async session used to run every statement.
        """
        self.model = model
        self.session = session

    async def create(self, **kwargs) -> ModelType:
        """Create a new record and add it to the session.

        Args:
            **kwargs: Constructor arguments passed straight to the model.

        Returns:
            The new, flushed (but not committed) model instance.
        """
        instance = self.model(**kwargs)
        self.session.add(instance)
        # Flush (not commit) so the INSERT is sent and database-generated
        # values such as an autoincrement id are available on `instance`.
        await self.session.flush()
        return instance

    async def get(self, id: int) -> Optional[ModelType]:
        """Get a record by ID.

        Args:
            id: Value matched against ``self.model.id``.

        Returns:
            The matching instance, or ``None`` if no row matches.
        """
        result = await self.session.execute(
            select(self.model).where(self.model.id == id)
        )
        return result.scalar_one_or_none()

    async def get_multi(
        self,
        skip: int = 0,
        limit: int = 100,
        **filters
    ) -> List[ModelType]:
        """Get multiple records with pagination and equality filters.

        Args:
            skip: Number of rows to skip (SQL OFFSET).
            limit: Maximum number of rows to return (SQL LIMIT).
            **filters: ``attribute=value`` pairs combined as equality
                conditions. Keys that are not attributes of the model are
                silently ignored.

        Returns:
            A list of matching model instances.
        """
        query = select(self.model)

        # Apply equality filters for attributes the model actually has.
        # NOTE(review): ignoring unknown keys means a misspelled filter
        # name returns unfiltered rows; consider raising instead.
        for key, value in filters.items():
            if hasattr(self.model, key):
                query = query.where(getattr(self.model, key) == value)

        query = query.offset(skip).limit(limit)
        result = await self.session.execute(query)
        # `.all()` is typed as a Sequence; materialize a list to honour the
        # declared `List[ModelType]` return type.
        return list(result.scalars().all())

    async def update(self, id: int, **kwargs) -> Optional[ModelType]:
        """Update a record and return its refreshed state.

        Args:
            id: Value matched against ``self.model.id``.
            **kwargs: Column values to set. When empty, no UPDATE is issued.

        Returns:
            The instance as re-read after the update, or ``None`` if no row
            has that id.
        """
        # An UPDATE with no SET values cannot be executed meaningfully, so
        # only run it when there is something to change.
        if kwargs:
            # `update` here is SQLAlchemy's module-level function: names
            # defined in a class body (like this method) are not visible
            # inside method bodies.
            await self.session.execute(
                update(self.model)
                .where(self.model.id == id)
                .values(**kwargs)
            )
        return await self.get(id)

    async def delete(self, id: int) -> bool:
        """Delete a record by ID.

        Args:
            id: Value matched against ``self.model.id``.

        Returns:
            ``True`` if the DELETE statement reported at least one affected
            row, otherwise ``False``.
        """
        # As in `update`, `delete` resolves to SQLAlchemy's module-level
        # function, not this method.
        result = await self.session.execute(
            delete(self.model).where(self.model.id == id)
        )
        return result.rowcount > 0