Compare commits

..

15 Commits

Author SHA1 Message Date
4ef428c620 项目管理对注释进行了修改,租户管理权限验证完成。 2024-11-20 22:04:13 +08:00
7af908a2af 修复tenant_id不能为空的问题
(cherry picked from commit 298830cd1e)
2024-11-20 21:18:29 +08:00
8d7d6f95ba 修复tenant_id不能为空的问题
(cherry picked from commit 8f2745aa41)
2024-11-20 21:17:24 +08:00
ee96d2b22e 项目管理权限验证完成,添加了登录的普通用户可以查看所属项目的项目信息 2024-11-20 21:02:44 +08:00
e86d299dc1 Merge branch 'refs/heads/dev/pjq'
# Conflicts:
#	api/manage_project.py
#	api/manage_tanant.py
#	api/manage_user.py
#	dependencies.py
2024-11-20 20:02:42 +08:00
bb11107f47 cors 2024-11-20 18:49:11 +08:00
15f467c23f 新增登录时返回用户类型 2024-11-20 18:15:09 +08:00
febc1eaca0 添加mysql支持 2024-11-20 17:49:51 +08:00
6f04dd699c 添加mysql支持 2024-11-20 17:32:32 +08:00
0912abdf29 连接到远程服务器测试通过。 2024-11-20 17:17:50 +08:00
dc40a82f93 更新一下DataSource 2024-11-20 16:26:52 +08:00
39ab82b655 修复了单元测试中发现的问题 2024-11-20 16:26:11 +08:00
60231d5272 优化import 2024-11-20 16:25:36 +08:00
de10d61129 用户管理完成。单元测试没成功。 2024-11-19 22:45:12 +08:00
f84a035640 注册功能成功。 2024-11-19 21:23:59 +08:00
11 changed files with 378 additions and 87 deletions

View File

@ -1,7 +1,29 @@
<?xml version="1.0" encoding="UTF-8"?> <?xml version="1.0" encoding="UTF-8"?>
<project version="4"> <project version="4">
<component name="DataSourceManagerImpl" format="xml" multifile-model="true"> <component name="DataSourceManagerImpl" format="xml" multifile-model="true">
<data-source source="LOCAL" name="test" uuid="a5872518-f7b2-4c54-8ab8-96976a3c432a"> <data-source source="LOCAL" name="@120.53.31.148" uuid="4aaf174b-2e9f-410b-829d-922cb73ec818">
<driver-ref>mysql.8</driver-ref>
<synchronize>true</synchronize>
<jdbc-driver>com.mysql.cj.jdbc.Driver</jdbc-driver>
<jdbc-url>jdbc:mysql://120.53.31.148:3306</jdbc-url>
<working-dir>$ProjectFileDir$</working-dir>
</data-source>
<data-source source="LOCAL" name="test" uuid="5db0e146-97a9-49aa-8412-64ee434ba67f">
<driver-ref>sqlite.xerial</driver-ref>
<synchronize>true</synchronize>
<jdbc-driver>org.sqlite.JDBC</jdbc-driver>
<jdbc-url>jdbc:sqlite:E:\pyprojects\CostEvalPlatform\test\test.sqlite3</jdbc-url>
<working-dir>$ProjectFileDir$</working-dir>
<libraries>
<library>
<url>file://$APPLICATION_CONFIG_DIR$/jdbc-drivers/Xerial SQLiteJDBC/3.45.1/org/xerial/sqlite-jdbc/3.45.1.0/sqlite-jdbc-3.45.1.0.jar</url>
</library>
<library>
<url>file://$APPLICATION_CONFIG_DIR$/jdbc-drivers/Xerial SQLiteJDBC/3.45.1/org/slf4j/slf4j-api/1.7.36/slf4j-api-1.7.36.jar</url>
</library>
</libraries>
</data-source>
<data-source source="LOCAL" name="main" uuid="7646579f-b137-4fd7-84f1-5179640337e5">
<driver-ref>sqlite.xerial</driver-ref> <driver-ref>sqlite.xerial</driver-ref>
<synchronize>true</synchronize> <synchronize>true</synchronize>
<jdbc-driver>org.sqlite.JDBC</jdbc-driver> <jdbc-driver>org.sqlite.JDBC</jdbc-driver>

View File

@ -3,15 +3,13 @@
# @Author : 河瞬 # @Author : 河瞬
# @FileName: login_reg.py # @FileName: login_reg.py
# @Software: PyCharm # @Software: PyCharm
from fastapi import HTTPException, Response, Depends, APIRouter
from typing import Optional, Annotated
from datetime import datetime, timedelta from datetime import datetime, timedelta
from jose import JWTError, jwt from typing import Optional
from sqlmodel import select from fastapi import APIRouter
from models import Tenant, User, Project
from dependencies import * from dependencies import *
from models import Tenant
router = APIRouter() router = APIRouter()
@ -32,12 +30,14 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None, s
# 登录路由 # 登录路由
@router.post("/api/s1/login") @router.post("/api/s1/login")
async def login(response: Response, user_data: dict, session: SessionDep): async def login(response: Response, user_data: dict, session: SessionDep):
if user_data.get('username') is None or user_data.get('password') is None:
raise HTTPException(status_code=401, detail="用户名或密码不能为空")
# 查询用户 # 查询用户
user = session.exec(select(User).where(User.username == user_data['username'])).first() user = session.exec(select(User).where(User.username == user_data['username'])).first()
# 验证用户名和密码 # 验证用户名和密码
if not user or user.password != user_data['password']: if not user or user.password != user_data['password']:
raise HTTPException(status_code=401, detail="Login failed") raise HTTPException(status_code=401, detail="登录失败,用户名或密码错误")
# 生成JWT token # 生成JWT token
token = create_access_token(data={"id": user.id, "role": user.role, "tanant_id": user.tenant.id}) token = create_access_token(data={"id": user.id, "role": user.role, "tanant_id": user.tenant.id})
@ -48,4 +48,20 @@ async def login(response: Response, user_data: dict, session: SessionDep):
# 关闭数据库会话 # 关闭数据库会话
session.close() session.close()
return {"message": f"Login successful"} return {"message": f"Login successful", "role": user.role}
@router.post("/api/s1/register")
async def register(data: dict, session: SessionDep):
if session.exec(select(Tenant).where(Tenant.name == data['name'])).first():
raise HTTPException(status_code=409, detail="租户名已存在")
if session.exec(select(User).where(User.username == data['username'])).first():
raise HTTPException(status_code=409, detail="用户名已存在")
tenant = Tenant(name=data['name'])
user = User(username=data['username'], password=data['password'], role=1, tenant=tenant)
session.add(tenant)
session.add(user)
session.commit()
session.close()
return {"detail": "注册成功"}

View File

@ -12,6 +12,7 @@ from sqlmodel import select
from models import Tenant, User, Project, ProjectUserLink from models import Tenant, User, Project, ProjectUserLink
from dependencies import * from dependencies import *
from fastapi import APIRouter
from typing import List from typing import List
@ -21,10 +22,29 @@ TenantRole = 1
# 列举所有项目 # 列举所有项目
@router.get("/api/s1/project") @router.get("/api/s1/project")
async def get_project(response: Response, session: SessionDep): async def get_project(response: Response, session: SessionDep, current_user: User = Depends(get_current_user)):
projects = session.query(Project).filter().all() # 只有角色为 0、1、2 或 3 的用户才可以访问
if current_user.role == 0:
# 角色为0显示所有项目
projects = session.query(Project).all()
elif current_user.role == 1:
# 角色为1显示tenant_id匹配的项目即属于当前租户的项目
projects = session.query(Project).filter(Project.owner_id == current_user.tenant_id).all()
elif current_user.role in [2, 3]:
# 角色为2或3显示与当前用户相关联的项目
projects = (
session.query(Project)
.join(ProjectUserLink)
.filter(ProjectUserLink.user_id == current_user.id)
.all()
)
else:
raise HTTPException(status_code=403, detail="You do not have permission to view projects.")
if not projects: if not projects:
raise HTTPException(status_code=404, detail="Project not found") raise HTTPException(status_code=404, detail="Project not found or you have no projects.")
# 返回项目的基本信息
return { return {
"projects": [ "projects": [
{ {
@ -37,9 +57,13 @@ async def get_project(response: Response, session: SessionDep):
] ]
} }
#新增与修改项目
# 新增与修改项目
@router.post("/api/s1/project") @router.post("/api/s1/project")
async def create_project(data: dict, session: SessionDep): async def create_project(data: dict, session: SessionDep, current_user: User = Depends(get_current_user)):
if current_user.role != 1:
raise HTTPException(status_code=403, detail="Only Tenant admin users can add or update projects.")
project_id = data.get("project_id") project_id = data.get("project_id")
name = data["name"] name = data["name"]
requirement = data["requirement"] requirement = data["requirement"]
@ -88,7 +112,7 @@ async def create_project(data: dict, session: SessionDep):
else: else:
# 新增项目 # 新增项目
exist_project = session.exec(select(Project).where(Project.name == name)).first() exist_project = session.exec(select(Project).where(Project.name == name)).first()
print(exist_project) #测试用 print(exist_project) # 测试用
if exist_project: if exist_project:
raise HTTPException(status_code=404, detail="Project already exists") raise HTTPException(status_code=404, detail="Project already exists")
@ -104,7 +128,7 @@ async def create_project(data: dict, session: SessionDep):
# 处理项目和用户的关联 # 处理项目和用户的关联
# 先清除现有的关联 # 先清除现有的关联
# 生成删除语句并执行 # 生成删除语句并执行
print(project_id) #测试用 print(project_id) # 测试用
stmt = delete(ProjectUserLink).where(ProjectUserLink.project_id == project.id) stmt = delete(ProjectUserLink).where(ProjectUserLink.project_id == project.id)
session.execute(stmt) session.execute(stmt)
session.commit() # 提交事务 session.commit() # 提交事务
@ -130,9 +154,13 @@ async def create_project(data: dict, session: SessionDep):
"information": project, "information": project,
} }
#删除项目
# 删除项目
@router.delete("/api/s1/project") @router.delete("/api/s1/project")
async def delete_project(data: dict, session: SessionDep): async def delete_project(data: dict, session: SessionDep, current_user: User = Depends(get_current_user)):
if current_user.role != 1:
raise HTTPException(status_code=403, detail="Only Tenant admin users can delete projects.")
project_name = data.get("name") project_name = data.get("name")
# 权限检查:只有管理员才可以删除项目 # 权限检查:只有管理员才可以删除项目
@ -147,7 +175,7 @@ async def delete_project(data: dict, session: SessionDep):
select(Project).where(Project.name == project_name)).first() select(Project).where(Project.name == project_name)).first()
if not project: if not project:
raise HTTPException(status_code=404,detail="Project not found") raise HTTPException(status_code=404, detail="Project not found")
# 删除与项目相关的用户链接 # 删除与项目相关的用户链接
# 先清除现有的关联 # 先清除现有的关联
@ -158,4 +186,4 @@ async def delete_project(data: dict, session: SessionDep):
session.delete(project) session.delete(project)
session.commit() session.commit()
return {"detail": "Project deleted successfully"} return {"detail": "Project deleted successfully"}

View File

@ -2,6 +2,7 @@
# @Time : 2024/11/19 下午8:04 # @Time : 2024/11/19 下午8:04
# @FileName: manage_tanant.py # @FileName: manage_tanant.py
# @Software: PyCharm # @Software: PyCharm
from fastapi import APIRouter
from fastapi import HTTPException, Response, Depends, APIRouter from fastapi import HTTPException, Response, Depends, APIRouter
from typing import Optional, Annotated from typing import Optional, Annotated
from datetime import datetime, timedelta from datetime import datetime, timedelta
@ -22,7 +23,11 @@ from dependencies import SessionDep # 假设 SessionDep 是数据库会话的
#列举所有租户 #列举所有租户
@router.get("/api/s1/tenant") @router.get("/api/s1/tenant")
async def get_tenant(response: Response, session: SessionDep): async def get_tenant(response: Response, session: SessionDep, current_user: User = Depends(get_current_user)):
if current_user.role != 0:
raise HTTPException(status_code=403, detail="Only Superadmin can list all tenants.")
tenants = session.query(Tenant).all() # 获取所有租户 tenants = session.query(Tenant).all() # 获取所有租户
if not tenants: if not tenants:
raise HTTPException(status_code=404, detail="No tenants found") raise HTTPException(status_code=404, detail="No tenants found")
@ -48,44 +53,21 @@ async def get_tenant(response: Response, session: SessionDep):
# 新增和修改租户 # 新增和修改租户
@router.post("/api/s1/tenant") @router.post("/api/s1/tenant")
async def create_or_update_tenant(data: dict, session: SessionDep): async def create_or_update_tenant(data: dict, session: SessionDep, current_user: User = Depends(get_current_user)):
name = data["name"] if current_user.role != 0:
username = data["username"] raise HTTPException(status_code=403, detail="Only Superadmin can add or update tenants.")
password = data.get("password", "") # 默认为空字符串
name = data.get("name")
username = data.get("username")
password = data.get("password")
# 验证是否缺少必要参数 # 验证是否缺少必要参数
if not name or not username: if not name:
raise HTTPException(status_code=400, detail="Need more name/username") raise HTTPException(status_code=400, detail="Need more name")
# 查找用户 if username:
user_query = select(User).where(User.username == username) # 如果 username 不为空,判断为新建租户
existing_user = session.exec(user_query).first() # 检查租户名是否已存在
# 如果密码为空,更新租户信息
if password == "":
print("密码为空") #测试用
# 如果用户不存在,返回错误
if not existing_user:
raise HTTPException(status_code=404, detail="User not found")
else:
# 如果找到了对应的 User
# 使用 user.tenant_id 查找对应的 Tenant
tenant = session.get(Tenant, existing_user.tenant_id)
# 如果 Tenant 存在,更新 Tenant 的 name 字段
if tenant:
tenant.name = name
session.commit() # 提交更新
else:
raise HTTPException(status_code=404, detail="Tenant not found")
return {"message": "Tenant and User update successfully"}
else:
print("密码不为空") #测试用
# 如果密码不为空,执行创建新租户和用户的操作
if existing_user:
# 如果用户已存在,返回错误
raise HTTPException(status_code=409, detail="User already exists")
# 检查租户是否已存在
tenant_query = select(Tenant).where(Tenant.name == name) tenant_query = select(Tenant).where(Tenant.name == name)
existing_tenant = session.exec(tenant_query).first() existing_tenant = session.exec(tenant_query).first()
@ -93,11 +75,7 @@ async def create_or_update_tenant(data: dict, session: SessionDep):
raise HTTPException(status_code=409, detail="Tenant name already exists") raise HTTPException(status_code=409, detail="Tenant name already exists")
# 创建新租户 # 创建新租户
tenant = Tenant( tenant = Tenant(name=name)
name=name,
username=username,
password=password, # 实际使用时应加密密码
)
session.add(tenant) session.add(tenant)
session.commit() session.commit()
session.refresh(tenant) session.refresh(tenant)
@ -105,21 +83,48 @@ async def create_or_update_tenant(data: dict, session: SessionDep):
# 创建新用户 # 创建新用户
user = User( user = User(
username=username, username=username,
password=password, # 同样需要加密密码 password=password, # 记得加密密码
role=1, # 默认role为1 role=1, # 默认role为1
tenant_id = tenant.id, tenant_id=tenant.id,
) )
session.add(user) session.add(user)
# 提交事务 # 提交事务
session.commit() session.commit()
session.refresh(tenant)
return {"message": "Tenant and User added successfully"} return {"message": "Tenant and User added successfully"}
else:
# 如果 username 为空,执行更新操作
# 根据租户名称查找 Tenant
tenant_query = select(Tenant).where(Tenant.name == name)
tenant = session.exec(tenant_query).first()
# 如果找不到对应的租户,抛出错误
if not tenant:
raise HTTPException(status_code=404, detail="Tenant not found")
# 找到租户后,根据 tenant_id 查找该租户下的所有用户
user_query = select(User).where(User.tenant_id == tenant.id)
user = session.exec(user_query).first()
#如果找不到对应的用户,抛出错误
if not user:
raise HTTPException(status_code=404, detail="User not found")
user.password = password
session.add(user)
session.commit()
print(user) #测试用
return {"message": "Tenant and User update successfully"}
#删除租户 #删除租户
@router.delete("/api/s1/tenant") @router.delete("/api/s1/tenant")
async def delete_tenant(data: dict, session: SessionDep): async def delete_tenant(data: dict, session: SessionDep, current_user: User = Depends(get_current_user)):
if current_user.role != 0:
raise HTTPException(status_code=403, detail="Only Superadmin can delete tenants.")
tenant_name = data.get("name") tenant_name = data.get("name")
if not tenant_name: if not tenant_name:

View File

@ -3,19 +3,74 @@
# @Author : 河瞬 # @Author : 河瞬
# @FileName: manage_user.py # @FileName: manage_user.py
# @Software: PyCharm # @Software: PyCharm
from fastapi import HTTPException, Response, Depends, APIRouter from fastapi import HTTPException, APIRouter, Depends, Request
from typing import Optional, Annotated
from datetime import datetime, timedelta
from jose import JWTError, jwt
from sqlmodel import select from sqlmodel import select
from models import Tenant, User, Project from dependencies import SessionDep, get_current_user
from dependencies import * from models import User
router = APIRouter() router = APIRouter()
# @router.get(...) # 枚举成员
# def example(): @router.get("/api/s1/user")
# return "hello" async def list_users(request: Request, session: SessionDep, current_user: User = Depends(get_current_user)):
if current_user.role != 1:
raise HTTPException(status_code=403, detail="Only admin users can list users")
users = session.exec(select(User).where(User.tenant_id == current_user.tenant_id)).all()
user_list = [{"username": user.username, "role": user.role} for user in users]
return user_list
# 新增和修改成员
@router.post("/api/s1/user")
async def add_or_update_user(data: dict, session: SessionDep, current_user: User = Depends(get_current_user)):
if current_user.role != 1:
raise HTTPException(status_code=403, detail="Only admin users can add or update users")
username = data.get("username")
password = data.get("password")
role = data.get("role")
if role not in ["auditor", "estimator"]:
raise HTTPException(status_code=400, detail="Invalid role")
role = 2 if role == "estimator" else 3
if not username or not role:
raise HTTPException(status_code=400, detail="Username and role are required")
user = session.exec(select(User).where(User.username == username, User.tenant_id == current_user.tenant_id)).first()
if user:
if password and password != "":
user.password = password
user.role = role
session.add(user)
session.commit()
return {"detail": "User updated successfully"}
else:
if password == "":
raise HTTPException(status_code=400, detail="Password is required for new user")
new_user = User(username=username, password=password, role=role, tenant_id=current_user.tenant_id)
session.add(new_user)
session.commit()
return {"detail": "User added successfully"}
# 删除成员
@router.delete("/api/s1/user")
async def delete_user(username: str, session: SessionDep, current_user: User = Depends(get_current_user)):
if current_user.role != 1:
raise HTTPException(status_code=403, detail="Only admin users can delete users")
# username = data.get("username")
if not username:
raise HTTPException(status_code=422, detail="Username is required")
user = session.exec(select(User).where(User.username == username, User.tenant_id == current_user.tenant_id)).first()
if not user:
raise HTTPException(status_code=404, detail="User not found")
session.delete(user)
session.commit()
return {"detail": "User deleted successfully"}

View File

@ -4,11 +4,13 @@
# @FileName: dependencies.py # @FileName: dependencies.py
# @Software: PyCharm # @Software: PyCharm
from typing import Annotated from typing import Annotated
from fastapi import Depends, Request, HTTPException, Cookie, Response
from fastapi import Depends, HTTPException, Cookie, Response
from jose import jwt, JWTError from jose import jwt, JWTError
from database import engine
from sqlmodel import Session, select from sqlmodel import Session, select
from config import Settings from config import Settings
from database import engine
from models import User from models import User

21
main.py
View File

@ -1,16 +1,10 @@
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, Response, Depends, APIRouter from fastapi import FastAPI
from typing import Optional, Annotated from fastapi.middleware.cors import CORSMiddleware
from datetime import datetime, timedelta
from jose import JWTError, jwt
from sqlmodel import Session, select
from database import create_db_and_tables, engine
from models import Tenant, User, Project
from dependencies import *
from api import login_reg, manage_project, manage_tanant, manage_user from api import login_reg, manage_project, manage_tanant, manage_user
from database import create_db_and_tables
# 用于生成和验证JWT的密钥 # 用于生成和验证JWT的密钥
SECRET_KEY = "your_secret_key" SECRET_KEY = "your_secret_key"
@ -27,6 +21,15 @@ async def lifespan(app: FastAPI):
app = FastAPI(lifespan=lifespan) app = FastAPI(lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # 允许所有来源,也可以指定具体的来源,例如 ["http://example.com", "https://example.com"]
allow_credentials=True, # 允许携带凭证如cookies
allow_methods=["*"], # 允许所有方法,也可以指定具体的方法,例如 ["GET", "POST", "PUT", "DELETE"]
allow_headers=["*"], # 允许所有头部,也可以指定具体的头部,例如 ["Content-Type", "Authorization"]
)
app.include_router(login_reg.router) app.include_router(login_reg.router)
app.include_router(manage_tanant.router) app.include_router(manage_tanant.router)
app.include_router(manage_user.router) app.include_router(manage_user.router)

View File

@ -6,7 +6,7 @@
from datetime import datetime from datetime import datetime
from typing import List, Optional from typing import List, Optional
from sqlmodel import Field, Relationship, SQLModel, DateTime from sqlmodel import Field, Relationship, SQLModel
class Tenant(SQLModel, table=True): class Tenant(SQLModel, table=True):
@ -29,7 +29,7 @@ class User(SQLModel, table=True):
username: str = Field(index=True) username: str = Field(index=True)
password: str password: str
role: int role: int
tenant_id: int = Field(default=None, foreign_key="Tenant.id") tenant_id: int | None = Field(default=None, foreign_key="Tenant.id")
tenant: Tenant = Relationship(back_populates="users") tenant: Tenant = Relationship(back_populates="users")
projects: List["Project"] = Relationship(back_populates="users", link_model=ProjectUserLink) projects: List["Project"] = Relationship(back_populates="users", link_model=ProjectUserLink)

View File

@ -3,4 +3,5 @@ python-jose~=3.3.0
uvicorn~=0.32.0 uvicorn~=0.32.0
pydantic~=2.9.2 pydantic~=2.9.2
pydantic-settings~=2.6.1 pydantic-settings~=2.6.1
mysqlclient
sqlmodel~=0.0.22 sqlmodel~=0.0.22

3
test/.env Normal file
View File

@ -0,0 +1,3 @@
ALGORITHM=HS256
DATABASE_URL=sqlite:///test.sqlite3
SECRET_KEY=your_secret_key

156
test/test_login_reg.py Normal file
View File

@ -0,0 +1,156 @@
import unittest
from fastapi.testclient import TestClient
# from database import engine,create_db_and_tables
from sqlmodel import select, Session, SQLModel, create_engine
from config import Settings
from main import app # 假设你的FastAPI应用实例名为app
from models import User, Tenant
settings = Settings()
engine = create_engine(settings.DATABASE_URL, echo=True)
def create_db_and_tables():
SQLModel.metadata.create_all(engine)
# 创建一个测试客户端
client = TestClient(app)
session: Session = None
class TestLoginReg(unittest.TestCase):
@classmethod
def setUpClass(cls):
global session
create_db_and_tables()
session = Session(engine)
session.close()
session = Session(engine)
@classmethod
def tearDownClass(cls):
session.close()
def setUp(self):
for item in session.exec(select(User)).all():
session.delete(item)
session.commit()
for item in session.exec(select(Tenant)).all():
session.delete(item)
session.commit()
# 创建一个模拟的数据库会话
# session.add()
def tearDown(self):
# 清理数据库
for item in session.exec(select(User)).all():
session.delete(item)
session.commit()
for item in session.exec(select(Tenant)).all():
session.delete(item)
session.commit()
# session.delete()
def test_setup_teardown(self):
pass
def test_login(self):
# 创建测试用户
test_tanant = Tenant(name="testtenant")
session.add(test_tanant)
session.commit()
test_user = User(username="testuser", password="testpassword", role=1, tenant=test_tanant)
session.add(test_user)
session.commit()
# 发送登录请求
response = client.post("/api/s1/login", json={"username": "testuser", "password": "testpassword"})
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {"message": "Login successful"})
def test_register(self):
# 发送注册请求
response = client.post("/api/s1/register",
json={"name": "regtenant", "username": "reguser", "password": "regpassword"})
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {"detail": "注册成功"})
response = client.post("/api/s1/register",
json={"name": "regtenant", "username": "rereguser", "password": "regpassword"})
self.assertEqual(response.status_code, 409)
self.assertEqual(response.json(), {"detail": "租户名已存在"})
response = client.post("/api/s1/register",
json={"name": "regretenant", "username": "reguser", "password": "regpassword"})
self.assertEqual(response.status_code, 409)
self.assertEqual(response.json(), {"detail": "用户名已存在"})
def get_cookie(self):
# 创建测试用户
test_tanant = Tenant(name="cookietenant")
session.add(test_tanant)
session.commit()
test_user = User(username="cookieuser", password="cookiepassword", role=1, tenant=test_tanant)
session.add(test_user)
session.commit()
# 发送登录请求以获取cookie
login_response = client.post("/api/s1/login", json={"username": "cookieuser", "password": "cookiepassword"})
self.assertEqual(login_response.status_code, 200)
return login_response.cookies.get("session_token")
def test_list_users(self):
# 发送枚举成员请求
cookie = self.get_cookie()
tenant = session.exec(select(Tenant).where(Tenant.name == "cookietenant")).first()
users = [User(username=f"testuser{i}", password="testpassword", role=2, tenant=tenant) for i in range(5)]
session.add_all(users)
session.commit()
response = client.get("/api/s1/user", cookies={"session_token": cookie})
self.assertEqual(response.status_code, 200)
self.assertEqual(len(response.json()), 6)
# 这里可以添加更多的断言来检查返回的用户列表
def test_add_or_update_user(self):
cookie = self.get_cookie()
# 测试新增成员
response = client.post("/api/s1/user",
json={"username": "newuser", "password": "newpassword", "role": "estimator"},
cookies={"session_token": cookie})
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {"detail": "User added successfully"})
# 测试修改成员
response = client.post("/api/s1/user", json={"username": "newuser", "password": "", "role": "auditor"},
cookies={"session_token": cookie})
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {"detail": "User updated successfully"})
user = session.exec(select(User).where(User.username == "newuser")).first()
self.assertEqual(user.role, 3)
def test_delete_user(self):
cookie = self.get_cookie()
# 添加一个用户用于测试删除
add_response = client.post("/api/s1/user",
json={"username": "todeleteuser", "password": "deletepassword", "role": "estimator"},
cookies={"session_token": cookie})
self.assertEqual(add_response.status_code, 200)
# 发送删除用户请求
# 错误请求
response = client.delete("/api/s1/user", params={"asdf": "notexistuser"}, cookies={"session_token": cookie})
self.assertEqual(response.status_code, 422)
# 不存在的用户
response = client.delete("/api/s1/user", params={"username": "notexistuser"}, cookies={"session_token": cookie})
self.assertEqual(response.status_code, 404)
# 正确的请求
response = client.delete("/api/s1/user", params={"username": "todeleteuser"}, cookies={"session_token": cookie})
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {"detail": "User deleted successfully"})
if __name__ == '__main__':
unittest.main()