diff --git a/api/manage_project.py b/api/manage_project.py index 36793b7..c6426f6 100644 --- a/api/manage_project.py +++ b/api/manage_project.py @@ -2,15 +2,7 @@ # @Time : 2024/11/19 下午8:05 # @FileName: manage_project.py # @Software: PyCharm -from fastapi import HTTPException, Response, Depends, APIRouter -from typing import Optional, Annotated -from datetime import datetime, timedelta -from jose import JWTError, jwt - -from sqlmodel import select - -from models import Tenant, User, Project -from dependencies import * +from fastapi import APIRouter router = APIRouter() diff --git a/api/manage_tanant.py b/api/manage_tanant.py index 9cb191e..04404e9 100644 --- a/api/manage_tanant.py +++ b/api/manage_tanant.py @@ -2,15 +2,7 @@ # @Time : 2024/11/19 下午8:04 # @FileName: manage_tanant.py # @Software: PyCharm -from fastapi import HTTPException, Response, Depends, APIRouter -from typing import Optional, Annotated -from datetime import datetime, timedelta -from jose import JWTError, jwt - -from sqlmodel import select - -from models import Tenant, User, Project -from dependencies import * +from fastapi import APIRouter router = APIRouter() diff --git a/dependencies.py b/dependencies.py index f6c480d..ff15377 100644 --- a/dependencies.py +++ b/dependencies.py @@ -4,11 +4,13 @@ # @FileName: dependencies.py # @Software: PyCharm from typing import Annotated -from fastapi import Depends, Request, HTTPException, Cookie, Response + +from fastapi import Depends, HTTPException, Cookie, Response from jose import jwt, JWTError -from database import engine from sqlmodel import Session, select + from config import Settings +from database import engine from models import User diff --git a/main.py b/main.py index cdd3238..85d608e 100644 --- a/main.py +++ b/main.py @@ -1,16 +1,9 @@ from contextlib import asynccontextmanager -from fastapi import FastAPI, HTTPException, Response, Depends, APIRouter -from typing import Optional, Annotated -from datetime import datetime, timedelta -from jose import JWTError, jwt +from fastapi import FastAPI -from sqlmodel import Session, select - -from database import create_db_and_tables, engine -from models import Tenant, User, Project -from dependencies import * from api import login_reg, manage_project, manage_tanant, manage_user +from database import create_db_and_tables # 用于生成和验证JWT的密钥 SECRET_KEY = "your_secret_key" diff --git a/models.py b/models.py index a9fb708..ba900b4 100644 --- a/models.py +++ b/models.py @@ -6,7 +6,7 @@ from datetime import datetime from typing import List, Optional -from sqlmodel import Field, Relationship, SQLModel, DateTime +from sqlmodel import Field, Relationship, SQLModel class Tenant(SQLModel, table=True): diff --git a/test/test_login_reg.py b/test/test_login_reg.py index d21c3c7..bd689fe 100644 --- a/test/test_login_reg.py +++ b/test/test_login_reg.py @@ -1,25 +1,15 @@ import unittest -from main import app # 假设你的FastAPI应用实例名为app -from fastapi import Depends + from fastapi.testclient import TestClient -from models import User, Tenant # from database import engine,create_db_and_tables -from sqlmodel import Session +from sqlmodel import select, Session, SQLModel, create_engine -from sqlmodel import Field, Session, SQLModel, create_engine +from config import Settings +from main import app # 假设你的FastAPI应用实例名为app +from models import User, Tenant - -class Hero(SQLModel, table=True): - id: int | None = Field(default=None, primary_key=True) - name: str - secret_name: str - age: int | None = None - - -sqlite_file_name = ":memory:" -sqlite_url = f"sqlite:///{sqlite_file_name}" - -engine = create_engine(sqlite_url, echo=True) +settings = Settings() +engine = create_engine(settings.DATABASE_URL, echo=True) def create_db_and_tables(): @@ -28,27 +18,53 @@ def create_db_and_tables(): # 创建一个测试客户端 client = TestClient(app) +session: Session = None class TestLoginReg(unittest.TestCase): @classmethod def setUpClass(cls): + global session create_db_and_tables() + session = Session(engine) + session.close() + session = Session(engine) + + @classmethod + def tearDownClass(cls): + session.close() def setUp(self): + for item in session.exec(select(User)).all(): + session.delete(item) + session.commit() + for item in session.exec(select(Tenant)).all(): + session.delete(item) + session.commit() # 创建一个模拟的数据库会话 - self.session = Session(engine) + # session.add() def tearDown(self): # 清理数据库 - # self.session.delete() - self.session.close() + for item in session.exec(select(User)).all(): + session.delete(item) + session.commit() + for item in session.exec(select(Tenant)).all(): + session.delete(item) + session.commit() + # session.delete() + + def test_setup_teardown(self): + pass def test_login(self): # 创建测试用户 - test_user = User(username="testuser", password="testpassword", role=1) - self.session.add(test_user) - self.session.commit() + test_tanant = Tenant(name="testtenant") + session.add(test_tanant) + session.commit() + test_user = User(username="testuser", password="testpassword", role=1, tenant=test_tanant) + session.add(test_user) + session.commit() # 发送登录请求 response = client.post("/api/s1/login", json={"username": "testuser", "password": "testpassword"}) @@ -58,37 +74,48 @@ class TestLoginReg(unittest.TestCase): def test_register(self): # 发送注册请求 response = client.post("/api/s1/register", - json={"name": "testtenant", "username": "newuser", "password": "newpassword"}) + json={"name": "regtenant", "username": "reguser", "password": "regpassword"}) self.assertEqual(response.status_code, 200) self.assertEqual(response.json(), {"detail": "注册成功"}) - def test_list_users(self): + response = client.post("/api/s1/register", + json={"name": "regtenant", "username": "rereguser", "password": "regpassword"}) + self.assertEqual(response.status_code, 409) + self.assertEqual(response.json(), {"detail": "租户名已存在"}) + + response = client.post("/api/s1/register", + json={"name": "regretenant", "username": "reguser", "password": "regpassword"}) + self.assertEqual(response.status_code, 409) + self.assertEqual(response.json(), {"detail": "用户名已存在"}) + + def get_cookie(self): # 创建测试用户 - test_user = User(username="testuser", password="testpassword", role=1) - self.session.add(test_user) - self.session.commit() + test_tanant = Tenant(name="cookietenant") + session.add(test_tanant) + session.commit() + test_user = User(username="cookieuser", password="cookiepassword", role=1, tenant=test_tanant) + session.add(test_user) + session.commit() # 发送登录请求以获取cookie - login_response = client.post("/api/s1/login", json={"username": "testuser", "password": "testpassword"}) + login_response = client.post("/api/s1/login", json={"username": "cookieuser", "password": "cookiepassword"}) self.assertEqual(login_response.status_code, 200) - cookie = login_response.cookies.get("session_token") + return login_response.cookies.get("session_token") + def test_list_users(self): # 发送枚举成员请求 + cookie = self.get_cookie() + tenant = session.exec(select(Tenant).where(Tenant.name == "cookietenant")).first() + users = [User(username=f"testuser{i}", password="testpassword", role=2, tenant=tenant) for i in range(5)] + session.add_all(users) + session.commit() response = client.get("/api/s1/user", cookies={"session_token": cookie}) self.assertEqual(response.status_code, 200) + self.assertEqual(len(response.json()), 6) # 这里可以添加更多的断言来检查返回的用户列表 def test_add_or_update_user(self): - # 创建测试用户 - test_user = User(username="testuser", password="testpassword", role=1) - self.session.add(test_user) - self.session.commit() - - # 发送登录请求以获取cookie - login_response = client.post("/api/s1/login", json={"username": "testuser", "password": "testpassword"}) - self.assertEqual(login_response.status_code, 200) - cookie = login_response.cookies.get("session_token") - + cookie = self.get_cookie() # 测试新增成员 response = client.post("/api/s1/user", json={"username": "newuser", "password": "newpassword", "role": "estimator"}, @@ -101,18 +128,11 @@ class TestLoginReg(unittest.TestCase): cookies={"session_token": cookie}) self.assertEqual(response.status_code, 200) self.assertEqual(response.json(), {"detail": "User updated successfully"}) + user = session.exec(select(User).where(User.username == "newuser")).first() + self.assertEqual(user.role, 3) def test_delete_user(self): - # 创建测试用户 - test_user = User(username="testuser", password="testpassword", role=1) - self.session.add(test_user) - self.session.commit() - - # 发送登录请求以获取cookie - login_response = client.post("/api/s1/login", json={"username": "testuser", "password": "testpassword"}) - self.assertEqual(login_response.status_code, 200) - cookie = login_response.cookies.get("session_token") - + cookie = self.get_cookie() # 添加一个用户用于测试删除 add_response = client.post("/api/s1/user", json={"username": "todeleteuser", "password": "deletepassword", "role": "estimator"}, @@ -120,7 +140,14 @@ class TestLoginReg(unittest.TestCase): self.assertEqual(add_response.status_code, 200) # 发送删除用户请求 - response = client.delete("/api/s1/user", json={"username": "todeleteuser"}, cookies={"session_token": cookie}) + # 错误请求 + response = client.delete("/api/s1/user", params={"asdf": "notexistuser"}, cookies={"session_token": cookie}) + self.assertEqual(response.status_code, 422) + # 不存在的用户 + response = client.delete("/api/s1/user", params={"username": "notexistuser"}, cookies={"session_token": cookie}) + self.assertEqual(response.status_code, 404) + # 正确的请求 + response = client.delete("/api/s1/user", params={"username": "todeleteuser"}, cookies={"session_token": cookie}) self.assertEqual(response.status_code, 200) self.assertEqual(response.json(), {"detail": "User deleted successfully"})