diff --git a/src/dao/sync_config.py b/src/dao/sync_config.py index 7124636..bc2519f 100644 --- a/src/dao/sync_config.py +++ b/src/dao/sync_config.py @@ -3,7 +3,7 @@ from sqlalchemy.exc import NoResultFound from src.do.sync_config import SyncBranchMapping, SyncRepoMapping, LogDO from .mysql_ao import MysqlAO from src.utils.base import Singleton -from src.dto.sync_config import AllRepoDTO, GetBranchDTO, SyncRepoDTO, SyncBranchDTO, RepoDTO +from src.dto.sync_config import AllRepoDTO, GetBranchDTO, SyncRepoDTO, SyncBranchDTO, RepoDTO, BranchDTO from typing import List from src.do.sync_config import SyncDirect, SyncType @@ -143,17 +143,18 @@ class SyncBranchDAO(BaseDAO, metaclass=Singleton): def __init__(self, *args, **kwargs): super().__init__(SyncBranchMapping, *args, **kwargs) - async def create_branch(self, dto: SyncBranchDTO, repo_id: int) -> SyncBranchDTO: + async def create_branch(self, dto: SyncBranchDTO, repo_id: int) -> BranchDTO: async with self._async_session() as session: async with session.begin(): do = SyncBranchMapping(**dto.dict(), repo_id=repo_id) session.add(do) - data = SyncBranchDTO( + await session.commit() + data = BranchDTO( + id=do.id, enable=do.enable, internal_branch_name=do.internal_branch_name, external_branch_name=do.external_branch_name ) - await session.commit() return data async def get_sync_branch(self, repo_id: int, page_number: int, page_size: int, create_sort: bool) -> List[GetBranchDTO]: diff --git a/src/dto/sync_config.py b/src/dto/sync_config.py index a192d09..6b3f4ac 100644 --- a/src/dto/sync_config.py +++ b/src/dto/sync_config.py @@ -17,6 +17,13 @@ class SyncBranchDTO(BaseModel): external_branch_name: str = Field(..., description="外部仓库分支名") +class BranchDTO(BaseModel): + id: int = Field(..., description="分支id") + enable: bool = Field(..., description="是否启用分支同步") + internal_branch_name: str = Field(..., description="内部仓库分支名") + external_branch_name: str = Field(..., description="外部仓库分支名") + + class RepoDTO(BaseModel): enable: bool = Field(..., description="是否启用同步") repo_name: str = Field(..., description="仓库名称") diff --git a/src/service/sync_config.py b/src/service/sync_config.py index 90eb2dd..e51f28c 100644 --- a/src/service/sync_config.py +++ b/src/service/sync_config.py @@ -2,7 +2,7 @@ import re from typing import List, Union, Optional, Dict from .service import Service from src.dao.sync_config import SyncBranchDAO, SyncRepoDAO, LogDAO -from src.dto.sync_config import SyncBranchDTO, SyncRepoDTO, RepoDTO, AllRepoDTO, GetBranchDTO, LogDTO +from src.dto.sync_config import SyncBranchDTO, SyncRepoDTO, RepoDTO, AllRepoDTO, GetBranchDTO, LogDTO, BranchDTO from src.do.sync_config import SyncDirect, SyncType from src.base.status_code import Status, SYNCException from src.utils.sync_log import log_path @@ -37,7 +37,7 @@ class SyncService(Service): raise SYNCException(Status.BRANCH_EXISTS) return repo.id - async def create_branch(self, dto: SyncBranchDTO, repo_id: int) -> Optional[SyncBranchDTO]: + async def create_branch(self, dto: SyncBranchDTO, repo_id: int) -> Optional[BranchDTO]: branch = await self.sync_branch_dao.create_branch(dto, repo_id=repo_id) return branch