优化import
This commit is contained in:
parent
de10d61129
commit
60231d5272
@ -2,15 +2,7 @@
|
|||||||
# @Time : 2024/11/19 下午8:05
|
# @Time : 2024/11/19 下午8:05
|
||||||
# @FileName: manage_project.py
|
# @FileName: manage_project.py
|
||||||
# @Software: PyCharm
|
# @Software: PyCharm
|
||||||
from fastapi import HTTPException, Response, Depends, APIRouter
|
from fastapi import APIRouter
|
||||||
from typing import Optional, Annotated
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from jose import JWTError, jwt
|
|
||||||
|
|
||||||
from sqlmodel import select
|
|
||||||
|
|
||||||
from models import Tenant, User, Project
|
|
||||||
from dependencies import *
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|||||||
@ -2,15 +2,7 @@
|
|||||||
# @Time : 2024/11/19 下午8:04
|
# @Time : 2024/11/19 下午8:04
|
||||||
# @FileName: manage_tanant.py
|
# @FileName: manage_tanant.py
|
||||||
# @Software: PyCharm
|
# @Software: PyCharm
|
||||||
from fastapi import HTTPException, Response, Depends, APIRouter
|
from fastapi import APIRouter
|
||||||
from typing import Optional, Annotated
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from jose import JWTError, jwt
|
|
||||||
|
|
||||||
from sqlmodel import select
|
|
||||||
|
|
||||||
from models import Tenant, User, Project
|
|
||||||
from dependencies import *
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|||||||
@ -4,11 +4,13 @@
|
|||||||
# @FileName: dependencies.py
|
# @FileName: dependencies.py
|
||||||
# @Software: PyCharm
|
# @Software: PyCharm
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
from fastapi import Depends, Request, HTTPException, Cookie, Response
|
|
||||||
|
from fastapi import Depends, HTTPException, Cookie, Response
|
||||||
from jose import jwt, JWTError
|
from jose import jwt, JWTError
|
||||||
from database import engine
|
|
||||||
from sqlmodel import Session, select
|
from sqlmodel import Session, select
|
||||||
|
|
||||||
from config import Settings
|
from config import Settings
|
||||||
|
from database import engine
|
||||||
from models import User
|
from models import User
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
11
main.py
11
main.py
@ -1,16 +1,9 @@
|
|||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from fastapi import FastAPI, HTTPException, Response, Depends, APIRouter
|
from fastapi import FastAPI
|
||||||
from typing import Optional, Annotated
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from jose import JWTError, jwt
|
|
||||||
|
|
||||||
from sqlmodel import Session, select
|
|
||||||
|
|
||||||
from database import create_db_and_tables, engine
|
|
||||||
from models import Tenant, User, Project
|
|
||||||
from dependencies import *
|
|
||||||
from api import login_reg, manage_project, manage_tanant, manage_user
|
from api import login_reg, manage_project, manage_tanant, manage_user
|
||||||
|
from database import create_db_and_tables
|
||||||
|
|
||||||
# 用于生成和验证JWT的密钥
|
# 用于生成和验证JWT的密钥
|
||||||
SECRET_KEY = "your_secret_key"
|
SECRET_KEY = "your_secret_key"
|
||||||
|
|||||||
@ -6,7 +6,7 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from sqlmodel import Field, Relationship, SQLModel, DateTime
|
from sqlmodel import Field, Relationship, SQLModel
|
||||||
|
|
||||||
|
|
||||||
class Tenant(SQLModel, table=True):
|
class Tenant(SQLModel, table=True):
|
||||||
|
|||||||
@ -1,25 +1,15 @@
|
|||||||
import unittest
|
import unittest
|
||||||
from main import app # 假设你的FastAPI应用实例名为app
|
|
||||||
from fastapi import Depends
|
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
from models import User, Tenant
|
|
||||||
# from database import engine,create_db_and_tables
|
# from database import engine,create_db_and_tables
|
||||||
from sqlmodel import Session
|
from sqlmodel import select, Session, SQLModel, create_engine
|
||||||
|
|
||||||
from sqlmodel import Field, Session, SQLModel, create_engine
|
from config import Settings
|
||||||
|
from main import app # 假设你的FastAPI应用实例名为app
|
||||||
|
from models import User, Tenant
|
||||||
|
|
||||||
|
settings = Settings()
|
||||||
class Hero(SQLModel, table=True):
|
engine = create_engine(settings.DATABASE_URL, echo=True)
|
||||||
id: int | None = Field(default=None, primary_key=True)
|
|
||||||
name: str
|
|
||||||
secret_name: str
|
|
||||||
age: int | None = None
|
|
||||||
|
|
||||||
|
|
||||||
sqlite_file_name = ":memory:"
|
|
||||||
sqlite_url = f"sqlite:///{sqlite_file_name}"
|
|
||||||
|
|
||||||
engine = create_engine(sqlite_url, echo=True)
|
|
||||||
|
|
||||||
|
|
||||||
def create_db_and_tables():
|
def create_db_and_tables():
|
||||||
@ -28,27 +18,53 @@ def create_db_and_tables():
|
|||||||
|
|
||||||
# 创建一个测试客户端
|
# 创建一个测试客户端
|
||||||
client = TestClient(app)
|
client = TestClient(app)
|
||||||
|
session: Session = None
|
||||||
|
|
||||||
|
|
||||||
class TestLoginReg(unittest.TestCase):
|
class TestLoginReg(unittest.TestCase):
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
|
global session
|
||||||
create_db_and_tables()
|
create_db_and_tables()
|
||||||
|
session = Session(engine)
|
||||||
|
session.close()
|
||||||
|
session = Session(engine)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
session.close()
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
for item in session.exec(select(User)).all():
|
||||||
|
session.delete(item)
|
||||||
|
session.commit()
|
||||||
|
for item in session.exec(select(Tenant)).all():
|
||||||
|
session.delete(item)
|
||||||
|
session.commit()
|
||||||
# 创建一个模拟的数据库会话
|
# 创建一个模拟的数据库会话
|
||||||
self.session = Session(engine)
|
# session.add()
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
# 清理数据库
|
# 清理数据库
|
||||||
# self.session.delete()
|
for item in session.exec(select(User)).all():
|
||||||
self.session.close()
|
session.delete(item)
|
||||||
|
session.commit()
|
||||||
|
for item in session.exec(select(Tenant)).all():
|
||||||
|
session.delete(item)
|
||||||
|
session.commit()
|
||||||
|
# session.delete()
|
||||||
|
|
||||||
|
def test_setup_teardown(self):
|
||||||
|
pass
|
||||||
|
|
||||||
def test_login(self):
|
def test_login(self):
|
||||||
# 创建测试用户
|
# 创建测试用户
|
||||||
test_user = User(username="testuser", password="testpassword", role=1)
|
test_tanant = Tenant(name="testtenant")
|
||||||
self.session.add(test_user)
|
session.add(test_tanant)
|
||||||
self.session.commit()
|
session.commit()
|
||||||
|
test_user = User(username="testuser", password="testpassword", role=1, tenant=test_tanant)
|
||||||
|
session.add(test_user)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
# 发送登录请求
|
# 发送登录请求
|
||||||
response = client.post("/api/s1/login", json={"username": "testuser", "password": "testpassword"})
|
response = client.post("/api/s1/login", json={"username": "testuser", "password": "testpassword"})
|
||||||
@ -58,37 +74,48 @@ class TestLoginReg(unittest.TestCase):
|
|||||||
def test_register(self):
|
def test_register(self):
|
||||||
# 发送注册请求
|
# 发送注册请求
|
||||||
response = client.post("/api/s1/register",
|
response = client.post("/api/s1/register",
|
||||||
json={"name": "testtenant", "username": "newuser", "password": "newpassword"})
|
json={"name": "regtenant", "username": "reguser", "password": "regpassword"})
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
self.assertEqual(response.json(), {"detail": "注册成功"})
|
self.assertEqual(response.json(), {"detail": "注册成功"})
|
||||||
|
|
||||||
def test_list_users(self):
|
response = client.post("/api/s1/register",
|
||||||
|
json={"name": "regtenant", "username": "rereguser", "password": "regpassword"})
|
||||||
|
self.assertEqual(response.status_code, 409)
|
||||||
|
self.assertEqual(response.json(), {"detail": "租户名已存在"})
|
||||||
|
|
||||||
|
response = client.post("/api/s1/register",
|
||||||
|
json={"name": "regretenant", "username": "reguser", "password": "regpassword"})
|
||||||
|
self.assertEqual(response.status_code, 409)
|
||||||
|
self.assertEqual(response.json(), {"detail": "用户名已存在"})
|
||||||
|
|
||||||
|
def get_cookie(self):
|
||||||
# 创建测试用户
|
# 创建测试用户
|
||||||
test_user = User(username="testuser", password="testpassword", role=1)
|
test_tanant = Tenant(name="cookietenant")
|
||||||
self.session.add(test_user)
|
session.add(test_tanant)
|
||||||
self.session.commit()
|
session.commit()
|
||||||
|
test_user = User(username="cookieuser", password="cookiepassword", role=1, tenant=test_tanant)
|
||||||
|
session.add(test_user)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
# 发送登录请求以获取cookie
|
# 发送登录请求以获取cookie
|
||||||
login_response = client.post("/api/s1/login", json={"username": "testuser", "password": "testpassword"})
|
login_response = client.post("/api/s1/login", json={"username": "cookieuser", "password": "cookiepassword"})
|
||||||
self.assertEqual(login_response.status_code, 200)
|
self.assertEqual(login_response.status_code, 200)
|
||||||
cookie = login_response.cookies.get("session_token")
|
return login_response.cookies.get("session_token")
|
||||||
|
|
||||||
|
def test_list_users(self):
|
||||||
# 发送枚举成员请求
|
# 发送枚举成员请求
|
||||||
|
cookie = self.get_cookie()
|
||||||
|
tenant = session.exec(select(Tenant).where(Tenant.name == "cookietenant")).first()
|
||||||
|
users = [User(username=f"testuser{i}", password="testpassword", role=2, tenant=tenant) for i in range(5)]
|
||||||
|
session.add_all(users)
|
||||||
|
session.commit()
|
||||||
response = client.get("/api/s1/user", cookies={"session_token": cookie})
|
response = client.get("/api/s1/user", cookies={"session_token": cookie})
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
|
self.assertEqual(len(response.json()), 6)
|
||||||
# 这里可以添加更多的断言来检查返回的用户列表
|
# 这里可以添加更多的断言来检查返回的用户列表
|
||||||
|
|
||||||
def test_add_or_update_user(self):
|
def test_add_or_update_user(self):
|
||||||
# 创建测试用户
|
cookie = self.get_cookie()
|
||||||
test_user = User(username="testuser", password="testpassword", role=1)
|
|
||||||
self.session.add(test_user)
|
|
||||||
self.session.commit()
|
|
||||||
|
|
||||||
# 发送登录请求以获取cookie
|
|
||||||
login_response = client.post("/api/s1/login", json={"username": "testuser", "password": "testpassword"})
|
|
||||||
self.assertEqual(login_response.status_code, 200)
|
|
||||||
cookie = login_response.cookies.get("session_token")
|
|
||||||
|
|
||||||
# 测试新增成员
|
# 测试新增成员
|
||||||
response = client.post("/api/s1/user",
|
response = client.post("/api/s1/user",
|
||||||
json={"username": "newuser", "password": "newpassword", "role": "estimator"},
|
json={"username": "newuser", "password": "newpassword", "role": "estimator"},
|
||||||
@ -101,18 +128,11 @@ class TestLoginReg(unittest.TestCase):
|
|||||||
cookies={"session_token": cookie})
|
cookies={"session_token": cookie})
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
self.assertEqual(response.json(), {"detail": "User updated successfully"})
|
self.assertEqual(response.json(), {"detail": "User updated successfully"})
|
||||||
|
user = session.exec(select(User).where(User.username == "newuser")).first()
|
||||||
|
self.assertEqual(user.role, 3)
|
||||||
|
|
||||||
def test_delete_user(self):
|
def test_delete_user(self):
|
||||||
# 创建测试用户
|
cookie = self.get_cookie()
|
||||||
test_user = User(username="testuser", password="testpassword", role=1)
|
|
||||||
self.session.add(test_user)
|
|
||||||
self.session.commit()
|
|
||||||
|
|
||||||
# 发送登录请求以获取cookie
|
|
||||||
login_response = client.post("/api/s1/login", json={"username": "testuser", "password": "testpassword"})
|
|
||||||
self.assertEqual(login_response.status_code, 200)
|
|
||||||
cookie = login_response.cookies.get("session_token")
|
|
||||||
|
|
||||||
# 添加一个用户用于测试删除
|
# 添加一个用户用于测试删除
|
||||||
add_response = client.post("/api/s1/user",
|
add_response = client.post("/api/s1/user",
|
||||||
json={"username": "todeleteuser", "password": "deletepassword", "role": "estimator"},
|
json={"username": "todeleteuser", "password": "deletepassword", "role": "estimator"},
|
||||||
@ -120,7 +140,14 @@ class TestLoginReg(unittest.TestCase):
|
|||||||
self.assertEqual(add_response.status_code, 200)
|
self.assertEqual(add_response.status_code, 200)
|
||||||
|
|
||||||
# 发送删除用户请求
|
# 发送删除用户请求
|
||||||
response = client.delete("/api/s1/user", json={"username": "todeleteuser"}, cookies={"session_token": cookie})
|
# 错误请求
|
||||||
|
response = client.delete("/api/s1/user", params={"asdf": "notexistuser"}, cookies={"session_token": cookie})
|
||||||
|
self.assertEqual(response.status_code, 422)
|
||||||
|
# 不存在的用户
|
||||||
|
response = client.delete("/api/s1/user", params={"username": "notexistuser"}, cookies={"session_token": cookie})
|
||||||
|
self.assertEqual(response.status_code, 404)
|
||||||
|
# 正确的请求
|
||||||
|
response = client.delete("/api/s1/user", params={"username": "todeleteuser"}, cookies={"session_token": cookie})
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
self.assertEqual(response.json(), {"detail": "User deleted successfully"})
|
self.assertEqual(response.json(), {"detail": "User deleted successfully"})
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user