用户管理完成。单元测试没成功。
This commit is contained in:
parent
f84a035640
commit
de10d61129
@ -16,5 +16,20 @@
|
||||
</library>
|
||||
</libraries>
|
||||
</data-source>
|
||||
<data-source source="LOCAL" name="test [2]" uuid="e913a282-2350-4d44-b0c2-132ec4d2c3d7">
|
||||
<driver-ref>sqlite.xerial</driver-ref>
|
||||
<synchronize>true</synchronize>
|
||||
<jdbc-driver>org.sqlite.JDBC</jdbc-driver>
|
||||
<jdbc-url>jdbc:sqlite:E:\pyprojects\CostEvalPlatform\test\test.db</jdbc-url>
|
||||
<working-dir>$ProjectFileDir$</working-dir>
|
||||
<libraries>
|
||||
<library>
|
||||
<url>file://$APPLICATION_CONFIG_DIR$/jdbc-drivers/Xerial SQLiteJDBC/3.45.1/org/xerial/sqlite-jdbc/3.45.1.0/sqlite-jdbc-3.45.1.0.jar</url>
|
||||
</library>
|
||||
<library>
|
||||
<url>file://$APPLICATION_CONFIG_DIR$/jdbc-drivers/Xerial SQLiteJDBC/3.45.1/org/slf4j/slf4j-api/1.7.36/slf4j-api-1.7.36.jar</url>
|
||||
</library>
|
||||
</libraries>
|
||||
</data-source>
|
||||
</component>
|
||||
</project>
|
||||
@ -3,16 +3,72 @@
|
||||
# @Author : 河瞬
|
||||
# @FileName: manage_user.py
|
||||
# @Software: PyCharm
|
||||
from fastapi import HTTPException, Response, Depends, APIRouter
|
||||
from typing import Optional, Annotated
|
||||
from datetime import datetime, timedelta
|
||||
from jose import JWTError, jwt
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
from models import Tenant, User, Project
|
||||
from dependencies import *
|
||||
from fastapi import HTTPException, APIRouter, Depends, Request
|
||||
from sqlmodel import select, Session
|
||||
from models import User, Tenant
|
||||
from dependencies import SessionDep, get_current_user
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# 枚举成员
|
||||
@router.get("/api/s1/user")
|
||||
async def list_users(request: Request, session: SessionDep, current_user: User = Depends(get_current_user)):
|
||||
if current_user.role != 1:
|
||||
raise HTTPException(status_code=403, detail="Only admin users can list users")
|
||||
|
||||
users = session.exec(select(User).where(User.tenant_id == current_user.tenant_id)).all()
|
||||
user_list = [{"username": user.username, "role": user.role} for user in users]
|
||||
return user_list
|
||||
|
||||
|
||||
# 新增和修改成员
|
||||
@router.post("/api/s1/user")
|
||||
async def add_or_update_user(data: dict, session: SessionDep, current_user: User = Depends(get_current_user)):
|
||||
if current_user.role != 1:
|
||||
raise HTTPException(status_code=403, detail="Only admin users can add or update users")
|
||||
|
||||
username = data.get("username")
|
||||
password = data.get("password")
|
||||
role = data.get("role")
|
||||
|
||||
if not username or not role:
|
||||
raise HTTPException(status_code=400, detail="Username and role are required")
|
||||
|
||||
user = session.exec(select(User).where(User.username == username, User.tenant_id == current_user.tenant_id)).first()
|
||||
|
||||
if user:
|
||||
if password == "":
|
||||
user.role = role
|
||||
else:
|
||||
user.password = password
|
||||
user.role = role
|
||||
session.add(user)
|
||||
session.commit()
|
||||
return {"detail": "User updated successfully"}
|
||||
else:
|
||||
if password == "":
|
||||
raise HTTPException(status_code=400, detail="Password is required for new user")
|
||||
new_user = User(username=username, password=password, role=role, tenant_id=current_user.tenant_id)
|
||||
session.add(new_user)
|
||||
session.commit()
|
||||
return {"detail": "User added successfully"}
|
||||
|
||||
|
||||
# 删除成员
|
||||
@router.delete("/api/s1/user")
|
||||
async def delete_user(username: str, session: SessionDep, current_user: User = Depends(get_current_user)):
|
||||
if current_user.role != 1:
|
||||
raise HTTPException(status_code=403, detail="Only admin users can delete users")
|
||||
|
||||
# username = data.get("username")
|
||||
if not username:
|
||||
raise HTTPException(status_code=400, detail="Username is required")
|
||||
|
||||
user = session.exec(select(User).where(User.username == username, User.tenant_id == current_user.tenant_id)).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
session.delete(user)
|
||||
session.commit()
|
||||
return {"detail": "User deleted successfully"}
|
||||
|
||||
@ -4,10 +4,12 @@
|
||||
# @FileName: dependencies.py
|
||||
# @Software: PyCharm
|
||||
from typing import Annotated
|
||||
from fastapi import Depends
|
||||
from fastapi import Depends, Request, HTTPException, Cookie, Response
|
||||
from jose import jwt, JWTError
|
||||
from database import engine
|
||||
from sqlmodel import Session
|
||||
from sqlmodel import Session, select
|
||||
from config import Settings
|
||||
from models import User
|
||||
|
||||
|
||||
def get_session():
|
||||
@ -22,3 +24,27 @@ def get_settings():
|
||||
SessionDep = Annotated[Session, Depends(get_session)]
|
||||
|
||||
SettingsDep = get_settings()
|
||||
|
||||
|
||||
def get_current_user(response: Response, session_token: Annotated[str | None, Cookie()] = None, db: SessionDep = None,
|
||||
settings: SettingsDep = SettingsDep):
|
||||
if not session_token:
|
||||
response.set_cookie(key="session_token", value="", httponly=True)
|
||||
raise HTTPException(status_code=401, detail="Not authenticated", )
|
||||
|
||||
try:
|
||||
payload = jwt.decode(session_token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
user_id = payload.get("id")
|
||||
if user_id is None:
|
||||
response.set_cookie(key="session_token", value="", httponly=True)
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
except JWTError:
|
||||
response.set_cookie(key="session_token", value="", httponly=True)
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
|
||||
user = db.exec(select(User).where(User.id == user_id)).first()
|
||||
if not user:
|
||||
response.set_cookie(key="session_token", value="", httponly=True)
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
|
||||
return user
|
||||
|
||||
3
test/.env
Normal file
3
test/.env
Normal file
@ -0,0 +1,3 @@
|
||||
ALGORITHM=HS256
|
||||
DATABASE_URL=sqlite:///:memory:
|
||||
SECRET_KEY=your_secret_key
|
||||
129
test/test_login_reg.py
Normal file
129
test/test_login_reg.py
Normal file
@ -0,0 +1,129 @@
|
||||
import unittest
|
||||
from main import app # 假设你的FastAPI应用实例名为app
|
||||
from fastapi import Depends
|
||||
from fastapi.testclient import TestClient
|
||||
from models import User, Tenant
|
||||
# from database import engine,create_db_and_tables
|
||||
from sqlmodel import Session
|
||||
|
||||
from sqlmodel import Field, Session, SQLModel, create_engine
|
||||
|
||||
|
||||
class Hero(SQLModel, table=True):
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
name: str
|
||||
secret_name: str
|
||||
age: int | None = None
|
||||
|
||||
|
||||
sqlite_file_name = ":memory:"
|
||||
sqlite_url = f"sqlite:///{sqlite_file_name}"
|
||||
|
||||
engine = create_engine(sqlite_url, echo=True)
|
||||
|
||||
|
||||
def create_db_and_tables():
|
||||
SQLModel.metadata.create_all(engine)
|
||||
|
||||
|
||||
# 创建一个测试客户端
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
class TestLoginReg(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
create_db_and_tables()
|
||||
|
||||
def setUp(self):
|
||||
# 创建一个模拟的数据库会话
|
||||
self.session = Session(engine)
|
||||
|
||||
def tearDown(self):
|
||||
# 清理数据库
|
||||
# self.session.delete()
|
||||
self.session.close()
|
||||
|
||||
def test_login(self):
|
||||
# 创建测试用户
|
||||
test_user = User(username="testuser", password="testpassword", role=1)
|
||||
self.session.add(test_user)
|
||||
self.session.commit()
|
||||
|
||||
# 发送登录请求
|
||||
response = client.post("/api/s1/login", json={"username": "testuser", "password": "testpassword"})
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(response.json(), {"message": "Login successful"})
|
||||
|
||||
def test_register(self):
|
||||
# 发送注册请求
|
||||
response = client.post("/api/s1/register",
|
||||
json={"name": "testtenant", "username": "newuser", "password": "newpassword"})
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(response.json(), {"detail": "注册成功"})
|
||||
|
||||
def test_list_users(self):
|
||||
# 创建测试用户
|
||||
test_user = User(username="testuser", password="testpassword", role=1)
|
||||
self.session.add(test_user)
|
||||
self.session.commit()
|
||||
|
||||
# 发送登录请求以获取cookie
|
||||
login_response = client.post("/api/s1/login", json={"username": "testuser", "password": "testpassword"})
|
||||
self.assertEqual(login_response.status_code, 200)
|
||||
cookie = login_response.cookies.get("session_token")
|
||||
|
||||
# 发送枚举成员请求
|
||||
response = client.get("/api/s1/user", cookies={"session_token": cookie})
|
||||
self.assertEqual(response.status_code, 200)
|
||||
# 这里可以添加更多的断言来检查返回的用户列表
|
||||
|
||||
def test_add_or_update_user(self):
|
||||
# 创建测试用户
|
||||
test_user = User(username="testuser", password="testpassword", role=1)
|
||||
self.session.add(test_user)
|
||||
self.session.commit()
|
||||
|
||||
# 发送登录请求以获取cookie
|
||||
login_response = client.post("/api/s1/login", json={"username": "testuser", "password": "testpassword"})
|
||||
self.assertEqual(login_response.status_code, 200)
|
||||
cookie = login_response.cookies.get("session_token")
|
||||
|
||||
# 测试新增成员
|
||||
response = client.post("/api/s1/user",
|
||||
json={"username": "newuser", "password": "newpassword", "role": "estimator"},
|
||||
cookies={"session_token": cookie})
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(response.json(), {"detail": "User added successfully"})
|
||||
|
||||
# 测试修改成员
|
||||
response = client.post("/api/s1/user", json={"username": "newuser", "password": "", "role": "auditor"},
|
||||
cookies={"session_token": cookie})
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(response.json(), {"detail": "User updated successfully"})
|
||||
|
||||
def test_delete_user(self):
|
||||
# 创建测试用户
|
||||
test_user = User(username="testuser", password="testpassword", role=1)
|
||||
self.session.add(test_user)
|
||||
self.session.commit()
|
||||
|
||||
# 发送登录请求以获取cookie
|
||||
login_response = client.post("/api/s1/login", json={"username": "testuser", "password": "testpassword"})
|
||||
self.assertEqual(login_response.status_code, 200)
|
||||
cookie = login_response.cookies.get("session_token")
|
||||
|
||||
# 添加一个用户用于测试删除
|
||||
add_response = client.post("/api/s1/user",
|
||||
json={"username": "todeleteuser", "password": "deletepassword", "role": "estimator"},
|
||||
cookies={"session_token": cookie})
|
||||
self.assertEqual(add_response.status_code, 200)
|
||||
|
||||
# 发送删除用户请求
|
||||
response = client.delete("/api/s1/user", json={"username": "todeleteuser"}, cookies={"session_token": cookie})
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(response.json(), {"detail": "User deleted successfully"})
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Loading…
Reference in New Issue
Block a user