2. 重构代码结构，拆分出独立的 Config 类与 ConnectionManager 类；3. 引入 TOML 来管理配置文件；4. 引入 pydantic 的 BaseModel 来简化 Config 类的构建；5. 新增信号响应系统，Linux 上现在可以响应 systemctl reload 了；6. 修改了部分注释
117 lines
3.2 KiB
Python
117 lines
3.2 KiB
Python
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, UploadFile, File
|
|
from fastapi.responses import StreamingResponse, HTMLResponse
|
|
from base64 import b64encode
|
|
from openai import OpenAI
|
|
from os.path import isfile
|
|
from sys import platform
|
|
from config import Config
|
|
from ConnectionManager import ConnectionManager
|
|
|
|
# FastAPI application object; every HTTP and WebSocket route below is registered on it.
app = FastAPI()
|
|
|
|
|
|
def signal_handler():
    """Reload ``config_file`` and rebuild the OpenAI client (SIGHUP hook).

    On success the module-level ``config`` and ``client`` globals are rebound
    together, so later requests use the new model, prompt, key and base URL.
    If loading fails, the previous ``config``/``client`` are kept and the
    error is printed instead of propagating.
    """
    global config, client

    print("Received SIGHUP, reloading config")
    try:
        # Config.load returns a *new* instance (the module-level setup does
        # `config = Config.load(config_file)`), so its result must be rebound.
        # Calling it on the old instance and discarding the result was a no-op.
        new_config = Config.load(config_file)
        new_client = OpenAI(api_key=new_config.key, base_url=new_config.base_url)
    except Exception as exc:
        # A signal handler is a top-level boundary: an exception here would be
        # raised into whatever the main thread (the event loop) was running.
        # A bad config file must not take the server down.
        print(f"Config reload failed, keeping previous config: {exc!r}")
        return

    # Swap both only after everything succeeded, so they never disagree.
    config, client = new_config, new_client
|
|
|
|
|
|
if platform != 'win32':
    # SIGHUP does not exist on Windows, so the reload hook is only installed
    # on other platforms. Per the commit note this is what makes
    # `systemctl reload` work on Linux (presumably the unit's ExecReload sends
    # SIGHUP to the process — confirm against the service file).
    import signal

    def _reload_on_sighup(_signum, _frame):
        # Python passes (signum, frame) to signal handlers; neither is needed.
        signal_handler()

    signal.signal(signal.SIGHUP, _reload_on_sighup)
|
|
|
|
|
|
def init():
    """Create ``config_file`` with a default ``Config()`` if it is missing.

    When the file does not exist, a freshly constructed ``Config()`` is bound
    to the module-level ``config`` global and saved to ``config_file``.
    When the file already exists, nothing happens.
    """
    global config

    if isfile(config_file):
        return

    config = Config()
    config.save(config_file)
|
|
|
|
|
|
# TOML config file, resolved relative to the process working directory.
config_file = "./config.toml"

# First run: write a default config file; then load whatever is on disk.
init()
config = Config.load(config_file)

# Client for the OpenAI-compatible API; signal_handler rebuilds it on reload.
client = OpenAI(
    api_key=config.key,
    base_url=config.base_url,
)

# Registry of connected WebSocket listeners used for broadcasting.
manager = ConnectionManager()
|
|
|
|
|
|
# WebSocket endpoint
@app.websocket("/listener")
async def event(websocket: WebSocket):
    """Register a listener socket and keep it open until the client leaves.

    The socket is added to ``manager`` so it receives the broadcasts made by
    ``/predict``. Text messages sent by the client are only printed.
    """
    await manager.connect(websocket)
    print("Client connected")
    try:
        while True:
            # Logic for handling messages from the client can be added here.
            data = await websocket.receive_text()
            print(f"Received message from client: {data}")
    except WebSocketDisconnect:
        # Normal client-initiated close; the cleanup below handles it.
        pass
    finally:
        # Always unregister, not only on WebSocketDisconnect. Otherwise any
        # other error leaves a dead socket that later broadcasts would target.
        # Examples: receive_text() on a binary frame, or task cancellation.
        manager.disconnect(websocket)
        print("Client disconnected")
|
|
|
|
|
|
# POST endpoint
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Send an uploaded image to the configured model and stream the reply.

    The image is base64-encoded and sent as a data URL, together with
    ``config.prompt``, in one user message to
    ``client.chat.completions.create(stream=True)``. Once the request has
    started, the image is broadcast to all WebSocket listeners as
    ``{"type": "image", ...}``. Each non-empty text delta from the model is
    then yielded to the HTTP caller and broadcast as ``{"type": "text", ...}``.

    Returns:
        A ``text/plain`` StreamingResponse of the model's text deltas.
    """
    # Bundled with FastAPI (re-exported from Starlette), which this module
    # already depends on; imported locally to keep the edit self-contained.
    from fastapi.concurrency import iterate_in_threadpool, run_in_threadpool

    # Read the image file and encode it as base64.
    image_data = await file.read()
    image_base64 = b64encode(image_data).decode('utf-8')

    # Label the data URL with the client's declared MIME type, so PNG/WebP
    # uploads are not mislabelled as JPEG. Fall back to JPEG (the old
    # behaviour) when the type is missing or not an image type.
    mime_type = file.content_type or ""
    if not mime_type.startswith("image/"):
        mime_type = "image/jpeg"

    # Build the API request. The OpenAI client is synchronous, so run it in a
    # worker thread instead of blocking the event loop (and every other
    # request and WebSocket) while the upstream call is in flight.
    response = await run_in_threadpool(
        client.chat.completions.create,
        model=config.model,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": config.prompt,
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:{mime_type};base64,{image_base64}"
                        },
                    },
                ],
            }
        ],
        stream=True,
    )

    # Broadcast the image data.
    await manager.broadcast_json({"type": "image", "content": image_base64})

    # Stream the result back.
    async def stream_response():
        # Each next() on the sync stream blocks on the network, so pull
        # chunks in a worker thread as well.
        async for chunk in iterate_in_threadpool(response):
            # Some OpenAI-compatible backends send chunks with an empty
            # `choices` list (e.g. a trailing usage chunk); skip them rather
            # than failing mid-stream with IndexError.
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            if content:
                yield content
                # Also broadcast the text to every connected WebSocket client.
                await manager.broadcast_json({"type": "text", "content": content})

    return StreamingResponse(stream_response(), media_type="text/plain")
|
|
|
|
|
|
# HTML pages
@app.get("/terminal", response_class=HTMLResponse)
async def test():
    """Serve ``html/terminal.html``, read fresh from disk on every request.

    The path is relative to the process working directory.
    """
    # NOTE(review): this handler and the /upload and / handlers are all named
    # `test`. The routes still register (the decorator runs at definition
    # time), but the module-level name only refers to the last one.
    page_path = "html/terminal.html"
    with open(page_path, mode="r", encoding="utf-8") as page:
        html = page.read()
    return html
|
|
|
|
|
|
@app.get("/upload", response_class=HTMLResponse)
async def test():
    """Serve ``html/upload.html``, read fresh from disk on every request.

    The path is relative to the process working directory.
    """
    # NOTE(review): shares the name `test` with the other page handlers; the
    # route is registered by the decorator regardless of later rebinding.
    page_path = "html/upload.html"
    with open(page_path, mode="r", encoding="utf-8") as page:
        html = page.read()
    return html
|
|
|
|
|
|
@app.get("/", response_class=HTMLResponse)
async def test():
    """Serve the landing page ``html/index.html``, read fresh per request.

    The path is relative to the process working directory.
    """
    # NOTE(review): shares the name `test` with the other page handlers; the
    # route is registered by the decorator regardless of later rebinding.
    page_path = "html/index.html"
    with open(page_path, mode="r", encoding="utf-8") as page:
        html = page.read()
    return html
|