From 293909ab8a4afcb7bb73e1cebd6f68bb9cd51a77 Mon Sep 17 00:00:00 2001 From: heshunme Date: Tue, 19 Nov 2024 19:28:05 +0800 Subject: [PATCH] =?UTF-8?q?remake=E5=B9=B6=E4=BF=AE=E6=AD=A3=E4=BA=86?= =?UTF-8?q?=E6=96=B0ORM=E7=9A=84=E4=BD=BF=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 46 +++++++++++++++++++++++++++++----------------- models.py | 15 ++++++++++----- 2 files changed, 39 insertions(+), 22 deletions(-) diff --git a/main.py b/main.py index ba1a1fe..7f5befc 100644 --- a/main.py +++ b/main.py @@ -1,26 +1,34 @@ -from fastapi import FastAPI, HTTPException, Response, Depends -from typing import Optional +from contextlib import asynccontextmanager + +from fastapi import FastAPI, HTTPException, Response, Depends, APIRouter +from typing import Optional, Annotated from datetime import datetime, timedelta from jose import JWTError, jwt -from models import * -app = FastAPI() +from sqlmodel import Session, select -# 创建数据库引擎 -engine = create_engine('sqlite:///test.db') - -# 创建所有表 -Base.metadata.create_all(engine) - -# 创建会话 -SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) -session = SessionLocal() +from database import create_db_and_tables, engine +from models import Tenant, User, Project # 用于生成和验证JWT的密钥 SECRET_KEY = "your_secret_key" ALGORITHM = "HS256" +# @app.on_event("startup") +# def on_startup(): +# create_db_and_tables() +@asynccontextmanager +async def lifespan(app: FastAPI): + create_db_and_tables() + yield + + +def get_session(): + with Session(engine) as session: + yield session + + # 生成JWT token def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): to_encode = data.copy() @@ -33,18 +41,22 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): return encoded_jwt +app = FastAPI(lifespan=lifespan) +SessionDep = Annotated[Session, Depends(get_session)] + + # 登录路由 @app.post("/api/s1/login") -async def login(response: Response, user_data: dict): +async def login(response: Response, user_data: dict, session: SessionDep): # 查询用户 - user = session.query(User).filter(User.name == user_data['username']).first() + user = session.exec(select(User).where(User.username == user_data['username'])).first() # 验证用户名和密码 if not user or user.password != user_data['password']: raise HTTPException(status_code=401, detail="Login failed") # 生成JWT token - token = create_access_token(data={"sub": user.name}) + token = create_access_token(data={"id": user.id, "role": user.role, "tanant_id": user.tenant.id}) # 设置cookie response.set_cookie(key="session_token", value=token, httponly=True) @@ -52,4 +64,4 @@ async def login(response: Response, user_data: dict): # 关闭数据库会话 session.close() - return {"message": "Login successful"} + return {"message": f"Login successful"} diff --git a/models.py b/models.py index d39af4b..a9fb708 100644 --- a/models.py +++ b/models.py @@ -17,6 +17,12 @@ class Tenant(SQLModel, table=True): projects: List["Project"] = Relationship(back_populates="owner") +class ProjectUserLink(SQLModel, table=True): + __tablename__ = 'ProjectUserLink' + project_id: int | None = Field(default=None, foreign_key="Project.id", primary_key=True) + user_id: int | None = Field(default=None, foreign_key="User.id", primary_key=True) + + class User(SQLModel, table=True): __tablename__ = 'User' id: Optional[int] = Field(default=None, primary_key=True) @@ -25,7 +31,7 @@ class User(SQLModel, table=True): role: int tenant_id: int = Field(default=None, foreign_key="Tenant.id") tenant: Tenant = Relationship(back_populates="users") - projects: List["Project"] = Relationship(back_populates="estimators") + projects: List["Project"] = Relationship(back_populates="users", link_model=ProjectUserLink) class Project(SQLModel, table=True): @@ -35,7 +41,6 @@ class Project(SQLModel, table=True): requirement: str owner_id: int = Field(default=None, foreign_key="Tenant.id") owner: Tenant = Relationship(back_populates="projects") - start_time: DateTime = Field(default=datetime.utcnow) - deadline: DateTime - estimators: List["User"] = Relationship(back_populates="projects") - auditors: List["User"] = Relationship(back_populates="projects") + start_time: datetime = Field(default=datetime.utcnow) + deadline: datetime + users: List["User"] = Relationship(back_populates="projects", link_model=ProjectUserLink)