diff --git a/api/manage_project.py b/api/manage_project.py index ed5d3ea..4c1d669 100644 --- a/api/manage_project.py +++ b/api/manage_project.py @@ -6,30 +6,159 @@ from fastapi import HTTPException, Response, Depends, APIRouter from typing import Optional, Annotated from datetime import datetime, timedelta from jose import JWTError, jwt +from sqlalchemy import delete from sqlmodel import select -from models import Tenant, User, Project +from models import Tenant, User, Project, ProjectUserLink from dependencies import * from typing import List router = APIRouter() +TenantRole = 1 -#列举所有项目 +# 列举所有项目 @router.get("/api/s1/project") -async def get_project(response:Response, session: SessionDep): +async def get_project(response: Response, session: SessionDep): projects = session.query(Project).filter().all() if not projects: raise HTTPException(status_code=404, detail="Project not found") return {"projects": projects} -#新增和修改项目 + +# 新增和修改项目 # @router.post("/api/s1/project") -# async def create_project(project_detail:dict, session: SessionDep): -# project = { -# "name": project_detail["name"], +# async def create_project(data:dict, session: SessionDep, current_user: User = Depends(get_current_user)): +# if current_user.role != TenantRole: +# raise HTTPException(status_code=403, detail="Only tenant users can add or update project") +# name = data["name"] +# requirement = data["requirement"] +# start_time = data["start_time"] +# deadline = data["deadline"] +# estimators = data["estimator"] +# auditors = data["auditor"] # -# } -# return {"newProject": project} \ No newline at end of file +# #验证是否缺少必要参数 +# if not name or not requirement or not start_time or not deadline or not estimators or not auditors: +# raise HTTPException(status_code=400, detail="Need more details") +# +# #验证开始时间是否早于结束时间 +# if datetime.fromisoformat(start_time) > datetime.fromisoformat(deadline): +# raise HTTPException(status_code=400, detail="Start time must be before deadline") +# +# #验证评估审核员是否存在 +# query_estimators = select(User).where(User.username.in_(estimators), User.tenant_id == current_user.tenant_id) +# users_estimators = session.exec(query_estimators).all() +# query_auditors = select(User).where(User.username.in_(auditors), User.tenant_id == current_user.tenant_id) +# users_auditors = session.exec(query_auditors).all() +# # 提取出所有查询到的 +# existing_estimators = {user.username for user in users_estimators} +# existing_auditors = {user.username for user in users_auditors} +# +# # 验证是否所有的username都存在于数据库中 +# missing_usernames = (set(auditors) | set(estimators)) - existing_estimators - existing_auditors +# +# if missing_usernames: +# raise HTTPException(status_code=404, detail=f"Missing usernames:{missing_usernames}") +# +# newProject = Project( +# name=name, +# requirement=requirement, +# start_time=start_time, +# deadline=deadline, +# owner_id = current_user.tenant_id, +# ) +# session.add(newProject) +# session.commit() +# session.refresh(newProject) +# +# return {"newProject": newProject, +# "refreshProject.id": newProject.id, +# } +@router.post("/api/s1/project") +async def create_project(data: dict, session: SessionDep): + project_id = data.get("project_id") + name = data["name"] + requirement = data["requirement"] + start_time_str = data["start_time"] + deadline_str = data["deadline"] + estimators = data["estimators"] + auditors = data["auditors"] + + # 验证是否缺少必要参数 + if not name or not requirement or not start_time_str or not deadline_str: + raise HTTPException(status_code=400, detail="Need more name/requirement/start_time/deadline") + + # 验证开始时间是否早于结束时间 + start_time = datetime.strptime(start_time_str, "%Y-%m-%d") + deadline = datetime.strptime(deadline_str, "%Y-%m-%d") + if start_time > deadline: + raise HTTPException(status_code=400, detail="Start time must be before deadline") + + # 验证评估审核员是否存在 + query_estimators = select(User).where(User.username.in_(estimators)) + users_estimators = session.exec(query_estimators).all() + query_auditors = select(User).where(User.username.in_(auditors)) + users_auditors = session.exec(query_auditors).all() + # 提取出所有查询到的 + existing_estimators = {user.username for user in users_estimators} + existing_auditors = {user.username for user in users_auditors} + + # 验证是否所有的username都存在于数据库中 + missing_usernames = (set(auditors) | set(estimators)) - existing_estimators - existing_auditors + + if missing_usernames: + raise HTTPException(status_code=404, detail=f"Missing usernames:{missing_usernames}") + + # 更新项目还是新增项目 + if project_id: + # 查找现有项目 + project = session.get(Project, project_id) + if not project: + raise HTTPException(status_code=404, detail="Project not found") + + # 更新项目内容 + project.name = name + project.requirement = requirement + project.start_time = start_time + project.deadline = deadline + else: + # 新增项目 + project = Project( + name=name, + requirement=requirement, + start_time=start_time, + deadline=deadline, + owner_id=1 # 假设owner_id是1 + ) + session.add(project) + + # 处理项目和用户的关联 + # 先清除现有的关联 + # 生成删除语句并执行 + stmt = delete(ProjectUserLink).where(ProjectUserLink.project_id == project.id) + session.execute(stmt) + session.commit() # 提交事务 + + # 重新建立与评估员和审核员的关系 + for username in estimators: + user = next((user for user in users_estimators if user.username == username), None) + if user: + project_user_link = ProjectUserLink(project_id=project.id, user_id=user.id) + session.add(project_user_link) + + for username in auditors: + user = next((user for user in users_auditors if user.username == username), None) + if user: + project_user_link = ProjectUserLink(project_id=project.id, user_id=user.id) + session.add(project_user_link) + + # 提交事务 + session.commit() + session.refresh(project) + + return {"newProject": project, + "refreshProject.id": project.id, + } diff --git a/dependencies.py b/dependencies.py index d5e01e7..f6c480d 100644 --- a/dependencies.py +++ b/dependencies.py @@ -4,10 +4,12 @@ # @FileName: dependencies.py # @Software: PyCharm from typing import Annotated -from fastapi import Depends +from fastapi import Depends, Request, HTTPException, Cookie, Response +from jose import jwt, JWTError from database import engine -from sqlmodel import Session +from sqlmodel import Session, select from config import Settings +from models import User def get_session(): @@ -22,3 +24,27 @@ def get_settings(): SessionDep = Annotated[Session, Depends(get_session)] SettingsDep = get_settings() + + +def get_current_user(response: Response, session_token: Annotated[str | None, Cookie()] = None, db: SessionDep = None, + settings: SettingsDep = SettingsDep): + if not session_token: + response.set_cookie(key="session_token", value="", httponly=True) + raise HTTPException(status_code=401, detail="Not authenticated", ) + + try: + payload = jwt.decode(session_token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) + user_id = payload.get("id") + if user_id is None: + response.set_cookie(key="session_token", value="", httponly=True) + raise HTTPException(status_code=401, detail="Invalid token") + except JWTError: + response.set_cookie(key="session_token", value="", httponly=True) + raise HTTPException(status_code=401, detail="Invalid token") + + user = db.exec(select(User).where(User.id == user_id)).first() + if not user: + response.set_cookie(key="session_token", value="", httponly=True) + raise HTTPException(status_code=401, detail="User not found") + + return user