diff --git a/api/manage_project.py b/api/manage_project.py index ea83f57..a8a51d7 100644 --- a/api/manage_project.py +++ b/api/manage_project.py @@ -3,6 +3,7 @@ # @FileName: manage_project.py # @Software: PyCharm from datetime import datetime +from typing import List from fastapi import APIRouter from sqlalchemy import delete @@ -56,7 +57,7 @@ async def create_project(data: dict, session: SessionDep, current_user: User = D if current_user.role != 1: raise HTTPException(status_code=403, detail="Only Tenant admin users can add or update projects.") - project_id = data.get("project_id") + # project_id = data.get("project_id") name = data.get("name") requirement = data.get("requirement") start_time_str = data.get("start_time") @@ -82,39 +83,26 @@ async def create_project(data: dict, session: SessionDep, current_user: User = D # 验证是否有传入评估/审核员 if not estimators or not auditors: raise HTTPException(status_code=400, detail="Need more estimators or auditors") + + users: List[User] = [] # 验证评估审核员是否存在 - query_estimators = select(User).where(User.username.in_(estimators)) - users_estimators = session.exec(query_estimators).all() - query_auditors = select(User).where(User.username.in_(auditors)) - users_auditors = session.exec(query_auditors).all() - # 提取出所有查询到的 - existing_estimators = {user.username for user in users_estimators} - existing_auditors = {user.username for user in users_auditors} + for username in estimators + auditors: + query_estimator = select(User).where(User.username == username) + if user := session.exec(query_estimator).first(): + users.append(user) - # 验证是否所有的username都存在于数据库中 - missing_usernames = (set(auditors) | set(estimators)) - existing_estimators - existing_auditors - - if missing_usernames: - raise HTTPException(status_code=404, detail=f"Missing usernames:{missing_usernames}") + project = session.exec(select(Project).where(Project.name == name)).first() + if project and project.owner_id != current_user.tenant_id: + raise HTTPException(status_code=403, detail="You do not have permission to modify this project.") # 更新项目还是新增项目 - if project_id: - # 查找现有项目 - project = session.get(Project, project_id) - if not project: - raise HTTPException(status_code=404, detail="Project not found") - + if project: # 更新项目内容 project.name = name project.requirement = requirement project.start_time = start_time project.deadline = deadline else: - # 新增项目 - exist_project = session.exec(select(Project).where(Project.name == name)).first() - if exist_project: - raise HTTPException(status_code=404, detail="Project already exists") - project = Project( name=name, requirement=requirement, @@ -122,27 +110,14 @@ async def create_project(data: dict, session: SessionDep, current_user: User = D deadline=deadline, owner_id=current_user.tenant_id, ) - session.add(project) # 处理项目和用户的关联 # 先清除现有的关联 # 生成删除语句并执行 - stmt = delete(ProjectUserLink).where(ProjectUserLink.project_id == project.id) - session.execute(stmt) - session.commit() # 提交事务 - - # 重新建立与评估员和审核员的关系 - for username in estimators: - user = next((user for user in users_estimators if user.username == username), None) - if user: - project_user_link = ProjectUserLink(project_id=project.id, user_id=user.id) - session.add(project_user_link) - - for username in auditors: - user = next((user for user in users_auditors if user.username == username), None) - if user: - project_user_link = ProjectUserLink(project_id=project.id, user_id=user.id) - session.add(project_user_link) + project.users = [] + project.users = users + session.add(project) + session.commit() # 提交事务 session.commit() @@ -155,11 +130,12 @@ async def create_project(data: dict, session: SessionDep, current_user: User = D # 删除项目 @router.delete("/api/s1/project") -async def delete_project(data: dict, session: SessionDep, current_user: User = Depends(get_current_user)): +async def delete_project(name: str, session: SessionDep, current_user: User = Depends(get_current_user)): if current_user.role != 1: raise HTTPException(status_code=403, detail="Only Tenant admin users can delete projects.") - project_name = data.get("name") + # project_name = data.get("name") + project_name = name if not project_name: raise HTTPException(status_code=400, detail="Project name is required")