1.最小化import

2.重构代码结构,独立config类与ConnectionManager类
3.引入toml来管理配置文件
4.引入pydantic的basemodel来简化config类的建构
5.新增信号响应系统,linux上可以响应systemctl reload了
6.修改了部分注释
This commit is contained in:
高子兴 2024-10-26 15:56:38 +08:00
parent 97eef02a92
commit 1ed9822f2f
4 changed files with 85 additions and 48 deletions

29
ConnectionManager.py Normal file
View File

@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
# @Time : 2024/10/26 下午3:35
# @Author : 河瞬
# @FileName: ConnectionManager.py
# @Software: PyCharm
# @Github :
from typing import List
from fastapi import WebSocket
# WebSocket连接管理器
class ConnectionManager:
def __init__(self):
self.active_connections: List[WebSocket] = []
async def connect(self, websocket: WebSocket):
await websocket.accept()
self.active_connections.append(websocket)
def disconnect(self, websocket: WebSocket):
self.active_connections.remove(websocket)
@staticmethod
async def send_personal_message(message: str, websocket: WebSocket):
await websocket.send_text(message)
async def broadcast_json(self, data: dict):
for connection in self.active_connections:
await connection.send_json(data)

24
config.py Normal file
View File

@ -0,0 +1,24 @@
# -*- coding: utf-8 -*-
# @Time : 2024/10/26 下午3:25
# @Author : 河瞬
# @FileName: config.py
# @Software: PyCharm
import toml
from pydantic import BaseModel
class Config(BaseModel):
key: str = ""
prompt: str = ""
model: str = "chatgpt-4o-latest"
base_url: str = ""
def save(self, config_file):
with open(config_file, "w", encoding='utf-8') as f:
toml.dump(self.model_dump(), f)
@classmethod
def load(cls, config_file):
with open(config_file, "r", encoding='utf-8') as f:
config_data = toml.load(f)
return cls(**config_data)

75
main.py
View File

@ -1,60 +1,40 @@
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, UploadFile, File from fastapi import FastAPI, WebSocket, WebSocketDisconnect, UploadFile, File
from fastapi.responses import StreamingResponse, HTMLResponse from fastapi.responses import StreamingResponse, HTMLResponse
from typing import List from base64 import b64encode
import base64 from openai import OpenAI
import openai from os.path import isfile
import signal from sys import platform
from config import Config
from ConnectionManager import ConnectionManager
app = FastAPI() app = FastAPI()
key = ""
prompt = ""
def get_key(): def signal_handler():
with open("key", "r") as f: print("Received SIGHUP, reloading config")
k = f.read() global config, client
return k config.load(config_file)
client = OpenAI(api_key=config.key, base_url=config.base_url)
def get_prompt(): if platform != 'win32':
with open("prompt", "r", encoding="utf-8") as f: import signal
p = f.read()
return p signal.signal(signal.SIGHUP, lambda signum, frame: signal_handler())
def load_config(): def init():
global key, prompt global config
key = get_key() if not isfile(config_file):
prompt = get_prompt() config = Config()
config.save(config_file)
signal.signal(signal.SIGHUP, load_config()) config_file = "./config.toml"
init()
load_config() config = Config.load(config_file)
client = openai.OpenAI(api_key=key,
base_url="https://open.bigmodel.cn/api/paas/v4/")
# WebSocket连接管理器
class ConnectionManager:
def __init__(self):
self.active_connections: List[WebSocket] = []
async def connect(self, websocket: WebSocket):
await websocket.accept()
self.active_connections.append(websocket)
def disconnect(self, websocket: WebSocket):
self.active_connections.remove(websocket)
@staticmethod
async def send_personal_message(message: str, websocket: WebSocket):
await websocket.send_text(message)
async def broadcast_json(self, data: dict):
for connection in self.active_connections:
await connection.send_json(data)
client = OpenAI(api_key=config.key, base_url=config.base_url)
manager = ConnectionManager() manager = ConnectionManager()
@ -78,18 +58,18 @@ async def event(websocket: WebSocket):
async def predict(file: UploadFile = File(...)): async def predict(file: UploadFile = File(...)):
# 读取图片文件并转换为base64编码 # 读取图片文件并转换为base64编码
image_data = await file.read() image_data = await file.read()
image_base64 = base64.b64encode(image_data).decode('utf-8') image_base64 = b64encode(image_data).decode('utf-8')
# 构造请求给API # 构造请求给API
response = client.chat.completions.create( response = client.chat.completions.create(
model="glm-4v", model=config.model,
messages=[ messages=[
{ {
"role": "user", "role": "user",
"content": [ "content": [
{ {
"type": "text", "type": "text",
"text": get_prompt(), "text": config.prompt,
}, },
{ {
"type": "image_url", "type": "image_url",
@ -117,6 +97,7 @@ async def predict(file: UploadFile = File(...)):
return StreamingResponse(stream_response(), media_type="text/plain") return StreamingResponse(stream_response(), media_type="text/plain")
# html页面
@app.get("/terminal", response_class=HTMLResponse) @app.get("/terminal", response_class=HTMLResponse)
async def test(): async def test():
with open("html/terminal.html", "r", encoding="utf-8") as f: with open("html/terminal.html", "r", encoding="utf-8") as f:

View File

@ -3,3 +3,6 @@ openai~=1.52.2
fastapi~=0.115.3 fastapi~=0.115.3
uvicorn~=0.20.0 uvicorn~=0.20.0
python-multipart~=0.0.12 python-multipart~=0.0.12
requests~=2.27.1
toml~=0.10.2
pydantic~=2.9.2