diff --git a/api/login_reg.py b/api/login_reg.py index 259c21a..0937b19 100644 --- a/api/login_reg.py +++ b/api/login_reg.py @@ -3,15 +3,13 @@ # @Author : 河瞬 # @FileName: login_reg.py # @Software: PyCharm -from fastapi import HTTPException, Response, Depends, APIRouter -from typing import Optional, Annotated from datetime import datetime, timedelta -from jose import JWTError, jwt +from typing import Optional -from sqlmodel import select +from fastapi import APIRouter -from models import Tenant, User, Project from dependencies import * +from models import Tenant router = APIRouter() @@ -53,14 +51,15 @@ async def login(response: Response, user_data: dict, session: SessionDep): @router.post("/api/s1/register") async def register(data: dict, session: SessionDep): - if session.exec(select(Tenant).where(Tenant.name == data['name'])): + if session.exec(select(Tenant).where(Tenant.name == data['name'])).first(): raise HTTPException(status_code=409, detail="租户名已存在") - if session.exec(select(User).where(User.username == data['username'])): + if session.exec(select(User).where(User.username == data['username'])).first(): raise HTTPException(status_code=409, detail="用户名已存在") tenant = Tenant(name=data['name']) user = User(username=data['username'], password=data['password'], role=1, tenant=tenant) session.add(tenant) session.add(user) + session.commit() session.close() return {"detail": "注册成功"} diff --git a/api/manage_user.py b/api/manage_user.py index 2024460..9900b1c 100644 --- a/api/manage_user.py +++ b/api/manage_user.py @@ -4,9 +4,10 @@ # @FileName: manage_user.py # @Software: PyCharm from fastapi import HTTPException, APIRouter, Depends, Request -from sqlmodel import select, Session -from models import User, Tenant +from sqlmodel import select + from dependencies import SessionDep, get_current_user +from models import User router = APIRouter() @@ -31,6 +32,9 @@ async def add_or_update_user(data: dict, session: SessionDep, current_user: User username = data.get("username") password = data.get("password") role = data.get("role") + if role not in ["auditor", "estimator"]: + raise HTTPException(status_code=400, detail="Invalid role") + role = 2 if role == "estimator" else 3 if not username or not role: raise HTTPException(status_code=400, detail="Username and role are required") @@ -63,7 +67,7 @@ async def delete_user(username: str, session: SessionDep, current_user: User = D # username = data.get("username") if not username: - raise HTTPException(status_code=400, detail="Username is required") + raise HTTPException(status_code=422, detail="Username is required") user = session.exec(select(User).where(User.username == username, User.tenant_id == current_user.tenant_id)).first() if not user: diff --git a/test/.env b/test/.env index 5c3af5a..9ad1c88 100644 --- a/test/.env +++ b/test/.env @@ -1,3 +1,3 @@ ALGORITHM=HS256 -DATABASE_URL=sqlite:///:memory: +DATABASE_URL=sqlite:///test.sqlite3 SECRET_KEY=your_secret_key