Compare commits

..

No commits in common. "dc40a82f93487af765975b8dbdae582ca6fe8224" and "077ecb8d5e3ad65bb297cb2db2bc7807fb6f65bb" have entirely different histories.

10 changed files with 55 additions and 306 deletions

View File

@ -1,29 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="DataSourceManagerImpl" format="xml" multifile-model="true">
<data-source source="LOCAL" name="@120.53.31.148" uuid="4aaf174b-2e9f-410b-829d-922cb73ec818">
<driver-ref>mysql.8</driver-ref>
<synchronize>true</synchronize>
<jdbc-driver>com.mysql.cj.jdbc.Driver</jdbc-driver>
<jdbc-url>jdbc:mysql://120.53.31.148:3306</jdbc-url>
<working-dir>$ProjectFileDir$</working-dir>
</data-source>
<data-source source="LOCAL" name="test" uuid="5db0e146-97a9-49aa-8412-64ee434ba67f">
<driver-ref>sqlite.xerial</driver-ref>
<synchronize>true</synchronize>
<jdbc-driver>org.sqlite.JDBC</jdbc-driver>
<jdbc-url>jdbc:sqlite:E:\pyprojects\CostEvalPlatform\test\test.sqlite3</jdbc-url>
<working-dir>$ProjectFileDir$</working-dir>
<libraries>
<library>
<url>file://$APPLICATION_CONFIG_DIR$/jdbc-drivers/Xerial SQLiteJDBC/3.45.1/org/xerial/sqlite-jdbc/3.45.1.0/sqlite-jdbc-3.45.1.0.jar</url>
</library>
<library>
<url>file://$APPLICATION_CONFIG_DIR$/jdbc-drivers/Xerial SQLiteJDBC/3.45.1/org/slf4j/slf4j-api/1.7.36/slf4j-api-1.7.36.jar</url>
</library>
</libraries>
</data-source>
<data-source source="LOCAL" name="main" uuid="7646579f-b137-4fd7-84f1-5179640337e5">
<data-source source="LOCAL" name="test" uuid="a5872518-f7b2-4c54-8ab8-96976a3c432a">
<driver-ref>sqlite.xerial</driver-ref>
<synchronize>true</synchronize>
<jdbc-driver>org.sqlite.JDBC</jdbc-driver>

View File

@ -3,13 +3,15 @@
# @Author : 河瞬
# @FileName: login_reg.py
# @Software: PyCharm
from fastapi import HTTPException, Response, Depends, APIRouter
from typing import Optional, Annotated
from datetime import datetime, timedelta
from typing import Optional
from jose import JWTError, jwt
from fastapi import APIRouter
from sqlmodel import select
from models import Tenant, User, Project
from dependencies import *
from models import Tenant
router = APIRouter()
@ -35,7 +37,7 @@ async def login(response: Response, user_data: dict, session: SessionDep):
# 验证用户名和密码
if not user or user.password != user_data['password']:
raise HTTPException(status_code=401, detail="登录失败,用户名或密码错误")
raise HTTPException(status_code=401, detail="Login failed")
# 生成JWT token
token = create_access_token(data={"id": user.id, "role": user.role, "tanant_id": user.tenant.id})
@ -47,19 +49,3 @@ async def login(response: Response, user_data: dict, session: SessionDep):
session.close()
return {"message": f"Login successful"}
@router.post("/api/s1/register")
async def register(data: dict, session: SessionDep):
if session.exec(select(Tenant).where(Tenant.name == data['name'])).first():
raise HTTPException(status_code=409, detail="租户名已存在")
if session.exec(select(User).where(User.username == data['username'])).first():
raise HTTPException(status_code=409, detail="用户名已存在")
tenant = Tenant(name=data['name'])
user = User(username=data['username'], password=data['password'], role=1, tenant=tenant)
session.add(tenant)
session.add(user)
session.commit()
session.close()
return {"detail": "注册成功"}

View File

@ -2,9 +2,19 @@
# @Time : 2024/11/19 下午8:05
# @FileName: manage_project.py
# @Software: PyCharm
from fastapi import APIRouter
from fastapi import HTTPException, Response, Depends, APIRouter
from typing import Optional, Annotated
from datetime import datetime, timedelta
from jose import JWTError, jwt
from sqlmodel import select
from models import Tenant, User, Project
from dependencies import *
router = APIRouter()
@router.get(...)
def example():
return "hello"

View File

@ -2,7 +2,19 @@
# @Time : 2024/11/19 下午8:04
# @FileName: manage_tanant.py
# @Software: PyCharm
from fastapi import APIRouter
from fastapi import HTTPException, Response, Depends, APIRouter
from typing import Optional, Annotated
from datetime import datetime, timedelta
from jose import JWTError, jwt
from sqlmodel import select
from models import Tenant, User, Project
from dependencies import *
router = APIRouter()
@router.get(...)
def example():
return "hello"

View File

@ -3,76 +3,19 @@
# @Author : 河瞬
# @FileName: manage_user.py
# @Software: PyCharm
from fastapi import HTTPException, APIRouter, Depends, Request
from fastapi import HTTPException, Response, Depends, APIRouter
from typing import Optional, Annotated
from datetime import datetime, timedelta
from jose import JWTError, jwt
from sqlmodel import select
from dependencies import SessionDep, get_current_user
from models import User
from models import Tenant, User, Project
from dependencies import *
router = APIRouter()
# 枚举成员
@router.get("/api/s1/user")
async def list_users(request: Request, session: SessionDep, current_user: User = Depends(get_current_user)):
if current_user.role != 1:
raise HTTPException(status_code=403, detail="Only admin users can list users")
users = session.exec(select(User).where(User.tenant_id == current_user.tenant_id)).all()
user_list = [{"username": user.username, "role": user.role} for user in users]
return user_list
# 新增和修改成员
@router.post("/api/s1/user")
async def add_or_update_user(data: dict, session: SessionDep, current_user: User = Depends(get_current_user)):
if current_user.role != 1:
raise HTTPException(status_code=403, detail="Only admin users can add or update users")
username = data.get("username")
password = data.get("password")
role = data.get("role")
if role not in ["auditor", "estimator"]:
raise HTTPException(status_code=400, detail="Invalid role")
role = 2 if role == "estimator" else 3
if not username or not role:
raise HTTPException(status_code=400, detail="Username and role are required")
user = session.exec(select(User).where(User.username == username, User.tenant_id == current_user.tenant_id)).first()
if user:
if password == "":
user.role = role
else:
user.password = password
user.role = role
session.add(user)
session.commit()
return {"detail": "User updated successfully"}
else:
if password == "":
raise HTTPException(status_code=400, detail="Password is required for new user")
new_user = User(username=username, password=password, role=role, tenant_id=current_user.tenant_id)
session.add(new_user)
session.commit()
return {"detail": "User added successfully"}
# 删除成员
@router.delete("/api/s1/user")
async def delete_user(username: str, session: SessionDep, current_user: User = Depends(get_current_user)):
if current_user.role != 1:
raise HTTPException(status_code=403, detail="Only admin users can delete users")
# username = data.get("username")
if not username:
raise HTTPException(status_code=422, detail="Username is required")
user = session.exec(select(User).where(User.username == username, User.tenant_id == current_user.tenant_id)).first()
if not user:
raise HTTPException(status_code=404, detail="User not found")
session.delete(user)
session.commit()
return {"detail": "User deleted successfully"}
@router.get(...)
def example():
return "hello"

View File

@ -4,14 +4,10 @@
# @FileName: dependencies.py
# @Software: PyCharm
from typing import Annotated
from fastapi import Depends, HTTPException, Cookie, Response
from jose import jwt, JWTError
from sqlmodel import Session, select
from config import Settings
from fastapi import Depends
from database import engine
from models import User
from sqlmodel import Session
from config import Settings
def get_session():
@ -26,27 +22,3 @@ def get_settings():
SessionDep = Annotated[Session, Depends(get_session)]
SettingsDep = get_settings()
def get_current_user(response: Response, session_token: Annotated[str | None, Cookie()] = None, db: SessionDep = None,
settings: SettingsDep = SettingsDep):
if not session_token:
response.set_cookie(key="session_token", value="", httponly=True)
raise HTTPException(status_code=401, detail="Not authenticated", )
try:
payload = jwt.decode(session_token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
user_id = payload.get("id")
if user_id is None:
response.set_cookie(key="session_token", value="", httponly=True)
raise HTTPException(status_code=401, detail="Invalid token")
except JWTError:
response.set_cookie(key="session_token", value="", httponly=True)
raise HTTPException(status_code=401, detail="Invalid token")
user = db.exec(select(User).where(User.id == user_id)).first()
if not user:
response.set_cookie(key="session_token", value="", httponly=True)
raise HTTPException(status_code=401, detail="User not found")
return user

11
main.py
View File

@ -1,9 +1,16 @@
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi import FastAPI, HTTPException, Response, Depends, APIRouter
from typing import Optional, Annotated
from datetime import datetime, timedelta
from jose import JWTError, jwt
from sqlmodel import Session, select
from database import create_db_and_tables, engine
from models import Tenant, User, Project
from dependencies import *
from api import login_reg, manage_project, manage_tanant, manage_user
from database import create_db_and_tables
# 用于生成和验证JWT的密钥
SECRET_KEY = "your_secret_key"

View File

@ -6,7 +6,7 @@
from datetime import datetime
from typing import List, Optional
from sqlmodel import Field, Relationship, SQLModel
from sqlmodel import Field, Relationship, SQLModel, DateTime
class Tenant(SQLModel, table=True):

View File

@ -1,3 +0,0 @@
ALGORITHM=HS256
DATABASE_URL=sqlite:///test.sqlite3
SECRET_KEY=your_secret_key

View File

@ -1,156 +0,0 @@
import unittest
from fastapi.testclient import TestClient
# from database import engine,create_db_and_tables
from sqlmodel import select, Session, SQLModel, create_engine
from config import Settings
from main import app # 假设你的FastAPI应用实例名为app
from models import User, Tenant
settings = Settings()
engine = create_engine(settings.DATABASE_URL, echo=True)
def create_db_and_tables():
SQLModel.metadata.create_all(engine)
# 创建一个测试客户端
client = TestClient(app)
session: Session = None
class TestLoginReg(unittest.TestCase):
@classmethod
def setUpClass(cls):
global session
create_db_and_tables()
session = Session(engine)
session.close()
session = Session(engine)
@classmethod
def tearDownClass(cls):
session.close()
def setUp(self):
for item in session.exec(select(User)).all():
session.delete(item)
session.commit()
for item in session.exec(select(Tenant)).all():
session.delete(item)
session.commit()
# 创建一个模拟的数据库会话
# session.add()
def tearDown(self):
# 清理数据库
for item in session.exec(select(User)).all():
session.delete(item)
session.commit()
for item in session.exec(select(Tenant)).all():
session.delete(item)
session.commit()
# session.delete()
def test_setup_teardown(self):
pass
def test_login(self):
# 创建测试用户
test_tanant = Tenant(name="testtenant")
session.add(test_tanant)
session.commit()
test_user = User(username="testuser", password="testpassword", role=1, tenant=test_tanant)
session.add(test_user)
session.commit()
# 发送登录请求
response = client.post("/api/s1/login", json={"username": "testuser", "password": "testpassword"})
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {"message": "Login successful"})
def test_register(self):
# 发送注册请求
response = client.post("/api/s1/register",
json={"name": "regtenant", "username": "reguser", "password": "regpassword"})
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {"detail": "注册成功"})
response = client.post("/api/s1/register",
json={"name": "regtenant", "username": "rereguser", "password": "regpassword"})
self.assertEqual(response.status_code, 409)
self.assertEqual(response.json(), {"detail": "租户名已存在"})
response = client.post("/api/s1/register",
json={"name": "regretenant", "username": "reguser", "password": "regpassword"})
self.assertEqual(response.status_code, 409)
self.assertEqual(response.json(), {"detail": "用户名已存在"})
def get_cookie(self):
# 创建测试用户
test_tanant = Tenant(name="cookietenant")
session.add(test_tanant)
session.commit()
test_user = User(username="cookieuser", password="cookiepassword", role=1, tenant=test_tanant)
session.add(test_user)
session.commit()
# 发送登录请求以获取cookie
login_response = client.post("/api/s1/login", json={"username": "cookieuser", "password": "cookiepassword"})
self.assertEqual(login_response.status_code, 200)
return login_response.cookies.get("session_token")
def test_list_users(self):
# 发送枚举成员请求
cookie = self.get_cookie()
tenant = session.exec(select(Tenant).where(Tenant.name == "cookietenant")).first()
users = [User(username=f"testuser{i}", password="testpassword", role=2, tenant=tenant) for i in range(5)]
session.add_all(users)
session.commit()
response = client.get("/api/s1/user", cookies={"session_token": cookie})
self.assertEqual(response.status_code, 200)
self.assertEqual(len(response.json()), 6)
# 这里可以添加更多的断言来检查返回的用户列表
def test_add_or_update_user(self):
cookie = self.get_cookie()
# 测试新增成员
response = client.post("/api/s1/user",
json={"username": "newuser", "password": "newpassword", "role": "estimator"},
cookies={"session_token": cookie})
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {"detail": "User added successfully"})
# 测试修改成员
response = client.post("/api/s1/user", json={"username": "newuser", "password": "", "role": "auditor"},
cookies={"session_token": cookie})
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {"detail": "User updated successfully"})
user = session.exec(select(User).where(User.username == "newuser")).first()
self.assertEqual(user.role, 3)
def test_delete_user(self):
cookie = self.get_cookie()
# 添加一个用户用于测试删除
add_response = client.post("/api/s1/user",
json={"username": "todeleteuser", "password": "deletepassword", "role": "estimator"},
cookies={"session_token": cookie})
self.assertEqual(add_response.status_code, 200)
# 发送删除用户请求
# 错误请求
response = client.delete("/api/s1/user", params={"asdf": "notexistuser"}, cookies={"session_token": cookie})
self.assertEqual(response.status_code, 422)
# 不存在的用户
response = client.delete("/api/s1/user", params={"username": "notexistuser"}, cookies={"session_token": cookie})
self.assertEqual(response.status_code, 404)
# 正确的请求
response = client.delete("/api/s1/user", params={"username": "todeleteuser"}, cookies={"session_token": cookie})
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {"detail": "User deleted successfully"})
if __name__ == '__main__':
unittest.main()