Compare commits

...

5 Commits

Author SHA1 Message Date
dc40a82f93 更新一下DataSource 2024-11-20 16:26:52 +08:00
39ab82b655 修复了单元测试中发现的问题 2024-11-20 16:26:11 +08:00
60231d5272 优化import 2024-11-20 16:25:36 +08:00
de10d61129 用户管理完成。单元测试没成功。 2024-11-19 22:45:12 +08:00
f84a035640 注册功能成功。 2024-11-19 21:23:59 +08:00
10 changed files with 306 additions and 55 deletions

View File

@ -1,7 +1,29 @@
<?xml version="1.0" encoding="UTF-8"?> <?xml version="1.0" encoding="UTF-8"?>
<project version="4"> <project version="4">
<component name="DataSourceManagerImpl" format="xml" multifile-model="true"> <component name="DataSourceManagerImpl" format="xml" multifile-model="true">
<data-source source="LOCAL" name="test" uuid="a5872518-f7b2-4c54-8ab8-96976a3c432a"> <data-source source="LOCAL" name="@120.53.31.148" uuid="4aaf174b-2e9f-410b-829d-922cb73ec818">
<driver-ref>mysql.8</driver-ref>
<synchronize>true</synchronize>
<jdbc-driver>com.mysql.cj.jdbc.Driver</jdbc-driver>
<jdbc-url>jdbc:mysql://120.53.31.148:3306</jdbc-url>
<working-dir>$ProjectFileDir$</working-dir>
</data-source>
<data-source source="LOCAL" name="test" uuid="5db0e146-97a9-49aa-8412-64ee434ba67f">
<driver-ref>sqlite.xerial</driver-ref>
<synchronize>true</synchronize>
<jdbc-driver>org.sqlite.JDBC</jdbc-driver>
<jdbc-url>jdbc:sqlite:E:\pyprojects\CostEvalPlatform\test\test.sqlite3</jdbc-url>
<working-dir>$ProjectFileDir$</working-dir>
<libraries>
<library>
<url>file://$APPLICATION_CONFIG_DIR$/jdbc-drivers/Xerial SQLiteJDBC/3.45.1/org/xerial/sqlite-jdbc/3.45.1.0/sqlite-jdbc-3.45.1.0.jar</url>
</library>
<library>
<url>file://$APPLICATION_CONFIG_DIR$/jdbc-drivers/Xerial SQLiteJDBC/3.45.1/org/slf4j/slf4j-api/1.7.36/slf4j-api-1.7.36.jar</url>
</library>
</libraries>
</data-source>
<data-source source="LOCAL" name="main" uuid="7646579f-b137-4fd7-84f1-5179640337e5">
<driver-ref>sqlite.xerial</driver-ref> <driver-ref>sqlite.xerial</driver-ref>
<synchronize>true</synchronize> <synchronize>true</synchronize>
<jdbc-driver>org.sqlite.JDBC</jdbc-driver> <jdbc-driver>org.sqlite.JDBC</jdbc-driver>

View File

@ -3,15 +3,13 @@
# @Author : 河瞬 # @Author : 河瞬
# @FileName: login_reg.py # @FileName: login_reg.py
# @Software: PyCharm # @Software: PyCharm
from fastapi import HTTPException, Response, Depends, APIRouter
from typing import Optional, Annotated
from datetime import datetime, timedelta from datetime import datetime, timedelta
from jose import JWTError, jwt from typing import Optional
from sqlmodel import select from fastapi import APIRouter
from models import Tenant, User, Project
from dependencies import * from dependencies import *
from models import Tenant
router = APIRouter() router = APIRouter()
@ -37,7 +35,7 @@ async def login(response: Response, user_data: dict, session: SessionDep):
# 验证用户名和密码 # 验证用户名和密码
if not user or user.password != user_data['password']: if not user or user.password != user_data['password']:
raise HTTPException(status_code=401, detail="Login failed") raise HTTPException(status_code=401, detail="登录失败,用户名或密码错误")
# 生成JWT token # 生成JWT token
token = create_access_token(data={"id": user.id, "role": user.role, "tanant_id": user.tenant.id}) token = create_access_token(data={"id": user.id, "role": user.role, "tanant_id": user.tenant.id})
@ -49,3 +47,19 @@ async def login(response: Response, user_data: dict, session: SessionDep):
session.close() session.close()
return {"message": f"Login successful"} return {"message": f"Login successful"}
@router.post("/api/s1/register")
async def register(data: dict, session: SessionDep):
if session.exec(select(Tenant).where(Tenant.name == data['name'])).first():
raise HTTPException(status_code=409, detail="租户名已存在")
if session.exec(select(User).where(User.username == data['username'])).first():
raise HTTPException(status_code=409, detail="用户名已存在")
tenant = Tenant(name=data['name'])
user = User(username=data['username'], password=data['password'], role=1, tenant=tenant)
session.add(tenant)
session.add(user)
session.commit()
session.close()
return {"detail": "注册成功"}

View File

@ -2,19 +2,9 @@
# @Time : 2024/11/19 下午8:05 # @Time : 2024/11/19 下午8:05
# @FileName: manage_project.py # @FileName: manage_project.py
# @Software: PyCharm # @Software: PyCharm
from fastapi import HTTPException, Response, Depends, APIRouter from fastapi import APIRouter
from typing import Optional, Annotated
from datetime import datetime, timedelta
from jose import JWTError, jwt
from sqlmodel import select
from models import Tenant, User, Project
from dependencies import *
router = APIRouter() router = APIRouter()
@router.get(...)
def example():
return "hello"

View File

@ -2,19 +2,7 @@
# @Time : 2024/11/19 下午8:04 # @Time : 2024/11/19 下午8:04
# @FileName: manage_tanant.py # @FileName: manage_tanant.py
# @Software: PyCharm # @Software: PyCharm
from fastapi import HTTPException, Response, Depends, APIRouter from fastapi import APIRouter
from typing import Optional, Annotated
from datetime import datetime, timedelta
from jose import JWTError, jwt
from sqlmodel import select
from models import Tenant, User, Project
from dependencies import *
router = APIRouter() router = APIRouter()
@router.get(...)
def example():
return "hello"

View File

@ -3,19 +3,76 @@
# @Author : 河瞬 # @Author : 河瞬
# @FileName: manage_user.py # @FileName: manage_user.py
# @Software: PyCharm # @Software: PyCharm
from fastapi import HTTPException, Response, Depends, APIRouter from fastapi import HTTPException, APIRouter, Depends, Request
from typing import Optional, Annotated
from datetime import datetime, timedelta
from jose import JWTError, jwt
from sqlmodel import select from sqlmodel import select
from models import Tenant, User, Project from dependencies import SessionDep, get_current_user
from dependencies import * from models import User
router = APIRouter() router = APIRouter()
@router.get(...) # 枚举成员
def example(): @router.get("/api/s1/user")
return "hello" async def list_users(request: Request, session: SessionDep, current_user: User = Depends(get_current_user)):
if current_user.role != 1:
raise HTTPException(status_code=403, detail="Only admin users can list users")
users = session.exec(select(User).where(User.tenant_id == current_user.tenant_id)).all()
user_list = [{"username": user.username, "role": user.role} for user in users]
return user_list
# 新增和修改成员
@router.post("/api/s1/user")
async def add_or_update_user(data: dict, session: SessionDep, current_user: User = Depends(get_current_user)):
if current_user.role != 1:
raise HTTPException(status_code=403, detail="Only admin users can add or update users")
username = data.get("username")
password = data.get("password")
role = data.get("role")
if role not in ["auditor", "estimator"]:
raise HTTPException(status_code=400, detail="Invalid role")
role = 2 if role == "estimator" else 3
if not username or not role:
raise HTTPException(status_code=400, detail="Username and role are required")
user = session.exec(select(User).where(User.username == username, User.tenant_id == current_user.tenant_id)).first()
if user:
if password == "":
user.role = role
else:
user.password = password
user.role = role
session.add(user)
session.commit()
return {"detail": "User updated successfully"}
else:
if password == "":
raise HTTPException(status_code=400, detail="Password is required for new user")
new_user = User(username=username, password=password, role=role, tenant_id=current_user.tenant_id)
session.add(new_user)
session.commit()
return {"detail": "User added successfully"}
# 删除成员
@router.delete("/api/s1/user")
async def delete_user(username: str, session: SessionDep, current_user: User = Depends(get_current_user)):
if current_user.role != 1:
raise HTTPException(status_code=403, detail="Only admin users can delete users")
# username = data.get("username")
if not username:
raise HTTPException(status_code=422, detail="Username is required")
user = session.exec(select(User).where(User.username == username, User.tenant_id == current_user.tenant_id)).first()
if not user:
raise HTTPException(status_code=404, detail="User not found")
session.delete(user)
session.commit()
return {"detail": "User deleted successfully"}

View File

@ -4,10 +4,14 @@
# @FileName: dependencies.py # @FileName: dependencies.py
# @Software: PyCharm # @Software: PyCharm
from typing import Annotated from typing import Annotated
from fastapi import Depends
from database import engine from fastapi import Depends, HTTPException, Cookie, Response
from sqlmodel import Session from jose import jwt, JWTError
from sqlmodel import Session, select
from config import Settings from config import Settings
from database import engine
from models import User
def get_session(): def get_session():
@ -22,3 +26,27 @@ def get_settings():
SessionDep = Annotated[Session, Depends(get_session)] SessionDep = Annotated[Session, Depends(get_session)]
SettingsDep = get_settings() SettingsDep = get_settings()
def get_current_user(response: Response, session_token: Annotated[str | None, Cookie()] = None, db: SessionDep = None,
settings: SettingsDep = SettingsDep):
if not session_token:
response.set_cookie(key="session_token", value="", httponly=True)
raise HTTPException(status_code=401, detail="Not authenticated", )
try:
payload = jwt.decode(session_token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
user_id = payload.get("id")
if user_id is None:
response.set_cookie(key="session_token", value="", httponly=True)
raise HTTPException(status_code=401, detail="Invalid token")
except JWTError:
response.set_cookie(key="session_token", value="", httponly=True)
raise HTTPException(status_code=401, detail="Invalid token")
user = db.exec(select(User).where(User.id == user_id)).first()
if not user:
response.set_cookie(key="session_token", value="", httponly=True)
raise HTTPException(status_code=401, detail="User not found")
return user

11
main.py
View File

@ -1,16 +1,9 @@
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, Response, Depends, APIRouter from fastapi import FastAPI
from typing import Optional, Annotated
from datetime import datetime, timedelta
from jose import JWTError, jwt
from sqlmodel import Session, select
from database import create_db_and_tables, engine
from models import Tenant, User, Project
from dependencies import *
from api import login_reg, manage_project, manage_tanant, manage_user from api import login_reg, manage_project, manage_tanant, manage_user
from database import create_db_and_tables
# 用于生成和验证JWT的密钥 # 用于生成和验证JWT的密钥
SECRET_KEY = "your_secret_key" SECRET_KEY = "your_secret_key"

View File

@ -6,7 +6,7 @@
from datetime import datetime from datetime import datetime
from typing import List, Optional from typing import List, Optional
from sqlmodel import Field, Relationship, SQLModel, DateTime from sqlmodel import Field, Relationship, SQLModel
class Tenant(SQLModel, table=True): class Tenant(SQLModel, table=True):

3
test/.env Normal file
View File

@ -0,0 +1,3 @@
ALGORITHM=HS256
DATABASE_URL=sqlite:///test.sqlite3
SECRET_KEY=your_secret_key

156
test/test_login_reg.py Normal file
View File

@ -0,0 +1,156 @@
import unittest
from fastapi.testclient import TestClient
# from database import engine,create_db_and_tables
from sqlmodel import select, Session, SQLModel, create_engine
from config import Settings
from main import app # 假设你的FastAPI应用实例名为app
from models import User, Tenant
settings = Settings()
engine = create_engine(settings.DATABASE_URL, echo=True)
def create_db_and_tables():
SQLModel.metadata.create_all(engine)
# 创建一个测试客户端
client = TestClient(app)
session: Session = None
class TestLoginReg(unittest.TestCase):
@classmethod
def setUpClass(cls):
global session
create_db_and_tables()
session = Session(engine)
session.close()
session = Session(engine)
@classmethod
def tearDownClass(cls):
session.close()
def setUp(self):
for item in session.exec(select(User)).all():
session.delete(item)
session.commit()
for item in session.exec(select(Tenant)).all():
session.delete(item)
session.commit()
# 创建一个模拟的数据库会话
# session.add()
def tearDown(self):
# 清理数据库
for item in session.exec(select(User)).all():
session.delete(item)
session.commit()
for item in session.exec(select(Tenant)).all():
session.delete(item)
session.commit()
# session.delete()
def test_setup_teardown(self):
pass
def test_login(self):
# 创建测试用户
test_tanant = Tenant(name="testtenant")
session.add(test_tanant)
session.commit()
test_user = User(username="testuser", password="testpassword", role=1, tenant=test_tanant)
session.add(test_user)
session.commit()
# 发送登录请求
response = client.post("/api/s1/login", json={"username": "testuser", "password": "testpassword"})
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {"message": "Login successful"})
def test_register(self):
# 发送注册请求
response = client.post("/api/s1/register",
json={"name": "regtenant", "username": "reguser", "password": "regpassword"})
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {"detail": "注册成功"})
response = client.post("/api/s1/register",
json={"name": "regtenant", "username": "rereguser", "password": "regpassword"})
self.assertEqual(response.status_code, 409)
self.assertEqual(response.json(), {"detail": "租户名已存在"})
response = client.post("/api/s1/register",
json={"name": "regretenant", "username": "reguser", "password": "regpassword"})
self.assertEqual(response.status_code, 409)
self.assertEqual(response.json(), {"detail": "用户名已存在"})
def get_cookie(self):
# 创建测试用户
test_tanant = Tenant(name="cookietenant")
session.add(test_tanant)
session.commit()
test_user = User(username="cookieuser", password="cookiepassword", role=1, tenant=test_tanant)
session.add(test_user)
session.commit()
# 发送登录请求以获取cookie
login_response = client.post("/api/s1/login", json={"username": "cookieuser", "password": "cookiepassword"})
self.assertEqual(login_response.status_code, 200)
return login_response.cookies.get("session_token")
def test_list_users(self):
# 发送枚举成员请求
cookie = self.get_cookie()
tenant = session.exec(select(Tenant).where(Tenant.name == "cookietenant")).first()
users = [User(username=f"testuser{i}", password="testpassword", role=2, tenant=tenant) for i in range(5)]
session.add_all(users)
session.commit()
response = client.get("/api/s1/user", cookies={"session_token": cookie})
self.assertEqual(response.status_code, 200)
self.assertEqual(len(response.json()), 6)
# 这里可以添加更多的断言来检查返回的用户列表
def test_add_or_update_user(self):
cookie = self.get_cookie()
# 测试新增成员
response = client.post("/api/s1/user",
json={"username": "newuser", "password": "newpassword", "role": "estimator"},
cookies={"session_token": cookie})
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {"detail": "User added successfully"})
# 测试修改成员
response = client.post("/api/s1/user", json={"username": "newuser", "password": "", "role": "auditor"},
cookies={"session_token": cookie})
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {"detail": "User updated successfully"})
user = session.exec(select(User).where(User.username == "newuser")).first()
self.assertEqual(user.role, 3)
def test_delete_user(self):
cookie = self.get_cookie()
# 添加一个用户用于测试删除
add_response = client.post("/api/s1/user",
json={"username": "todeleteuser", "password": "deletepassword", "role": "estimator"},
cookies={"session_token": cookie})
self.assertEqual(add_response.status_code, 200)
# 发送删除用户请求
# 错误请求
response = client.delete("/api/s1/user", params={"asdf": "notexistuser"}, cookies={"session_token": cookie})
self.assertEqual(response.status_code, 422)
# 不存在的用户
response = client.delete("/api/s1/user", params={"username": "notexistuser"}, cookies={"session_token": cookie})
self.assertEqual(response.status_code, 404)
# 正确的请求
response = client.delete("/api/s1/user", params={"username": "todeleteuser"}, cookies={"session_token": cookie})
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {"detail": "User deleted successfully"})
if __name__ == '__main__':
unittest.main()