ExerciseOCR/main.py
heshunme 1ed9822f2f 1.最小化import
2.重构代码结构,独立config类与ConnectionManager类
3.引入toml来管理配置文件
4.引入pydantic的BaseModel来简化config类的建构
5.新增信号响应系统,linux上可以响应systemctl reload了
6.修改了部分注释
2024-10-26 15:56:38 +08:00

117 lines
3.2 KiB
Python

from fastapi import FastAPI, WebSocket, WebSocketDisconnect, UploadFile, File
from fastapi.responses import StreamingResponse, HTMLResponse
from base64 import b64encode
from openai import OpenAI
from os.path import isfile
from sys import platform
from config import Config
from ConnectionManager import ConnectionManager
# ASGI application; every endpoint below is registered on it via decorators.
app: FastAPI = FastAPI()
def signal_handler():
    """Reload ``config_file`` and rebuild the OpenAI client (SIGHUP handler).

    ``Config.load`` returns the loaded configuration (exactly as it is used at
    startup), so its result is assigned back to the module-level ``config``.
    The new config and client are both built before either global is replaced:
    if loading fails, the previous ``config`` and ``client`` stay in effect and
    the error is reported instead of propagating.
    """
    global config, client
    print("Received SIGHUP, reloading config")
    try:
        new_config = Config.load(config_file)
        new_client = OpenAI(api_key=new_config.key, base_url=new_config.base_url)
    except Exception as exc:
        # An exception escaping a Python signal handler is raised in the main
        # thread at an arbitrary point (here: inside the server's event loop),
        # so a broken config file must not be allowed to take the server down.
        print(f"Config reload failed, keeping previous config: {exc!r}")
        return
    # Swap both together so requests never see a config/client mismatch.
    config, client = new_config, new_client
# SIGHUP triggers a live config reload (e.g. `kill -HUP <pid>`). Windows has
# no SIGHUP, so the handler is only installed on other platforms.
if platform != 'win32':
    import signal

    def _on_sighup(signum, frame):
        """Adapt signal.signal's (signum, frame) callback to signal_handler()."""
        signal_handler()

    signal.signal(signal.SIGHUP, _on_sighup)
def init():
    """Write a default config file unless one already exists at ``config_file``.

    When the file is missing, the module-level ``config`` is bound to a fresh
    ``Config()`` and that default is saved to ``config_file``; an existing
    file is left untouched.
    """
    global config
    if isfile(config_file):
        return
    config = Config()
    config.save(config_file)
# Location of the TOML configuration, relative to the working directory.
config_file = "./config.toml"

# Ensure a config file exists (writing defaults if needed), then load it and
# build the OpenAI client from its credentials.
init()
config = Config.load(config_file)
client = OpenAI(base_url=config.base_url, api_key=config.key)

# Tracks connected WebSocket listeners for broadcasting.
manager = ConnectionManager()
# WebSocket endpoint: listeners connect here to receive what /predict broadcasts.
@app.websocket("/listener")
async def event(websocket: WebSocket):
    """Register a listener and keep its connection open until it goes away.

    Messages sent by the client are only printed. The socket is removed from
    ``manager`` however the receive loop ends -- a normal
    ``WebSocketDisconnect`` or any other error (e.g. an unexpected message
    type) -- so it does not stay registered in ``manager`` (presumably the set
    that later broadcasts are sent to).
    """
    await manager.connect(websocket)
    print("Client connected")
    try:
        while True:
            # Room for handling messages sent by the client.
            data = await websocket.receive_text()
            print(f"Received message from client: {data}")
    except WebSocketDisconnect:
        pass  # normal client-initiated close; cleanup happens in `finally`
    finally:
        manager.disconnect(websocket)
# POST endpoint: run the configured vision model on an uploaded image.
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Send an uploaded image plus ``config.prompt`` to the model and stream the reply.

    The image is broadcast to WebSocket listeners as
    ``{"type": "image", "content": <base64>}``. Each non-empty text delta of
    the model's streamed answer is both yielded to the HTTP caller (as
    ``text/plain``) and broadcast as ``{"type": "text", "content": <delta>}``.

    Raises:
        HTTPException: 400 if the uploaded file is empty.
    """
    # Function-scope imports keep this endpoint self-contained; Python caches
    # them after the first request.
    import asyncio
    from functools import partial

    from fastapi import HTTPException

    # Read the uploaded image and base64-encode it for a data URL.
    image_data = await file.read()
    if not image_data:
        # Nothing to recognise; fail fast instead of sending an empty image.
        raise HTTPException(status_code=400, detail="Uploaded file is empty")
    image_base64 = b64encode(image_data).decode('utf-8')

    # Label the data URL with the image type the client declared (PNG, WebP,
    # ...) and only fall back to JPEG when no image type was given.
    declared_type = file.content_type or ""
    mime_type = declared_type if declared_type.startswith("image/") else "image/jpeg"

    # `client` is the synchronous OpenAI client: run the blocking request in
    # the default thread pool so the event loop (and with it every other
    # request and WebSocket) is not stalled while waiting for the API.
    loop = asyncio.get_running_loop()
    response = await loop.run_in_executor(
        None,
        partial(
            client.chat.completions.create,
            model=config.model,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": config.prompt,
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:{mime_type};base64,{image_base64}"
                            },
                        },
                    ],
                }
            ],
            stream=True,
        ),
    )

    # Broadcast the image to the WebSocket listeners.
    await manager.broadcast_json({"type": "image", "content": image_base64})

    async def stream_response():
        """Yield each non-empty text delta, mirroring it to WebSocket listeners."""
        chunks = iter(response)
        finished = object()  # sentinel: keeps StopIteration out of the Future
        while True:
            # Waiting for the next chunk blocks on the network, so it also
            # runs in the thread pool.
            chunk = await loop.run_in_executor(None, next, chunks, finished)
            if chunk is finished:
                break
            # Chunks may carry an empty `choices` list (e.g. a usage-only
            # chunk from some OpenAI-compatible backends); skip them.
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            if content:
                yield content
                # Also push the text to every connected WebSocket client.
                await manager.broadcast_json({"type": "text", "content": content})

    return StreamingResponse(stream_response(), media_type="text/plain")
# HTML pages
@app.get("/terminal", response_class=HTMLResponse)
async def test():
    """Return the contents of ``html/terminal.html`` (path relative to the CWD)."""
    page_path = "html/terminal.html"
    with open(page_path, "r", encoding="utf-8") as page:
        html = page.read()
    return html
@app.get("/upload", response_class=HTMLResponse)
async def test():
    """Return the contents of ``html/upload.html`` (path relative to the CWD)."""
    page_path = "html/upload.html"
    with open(page_path, "r", encoding="utf-8") as page:
        html = page.read()
    return html
@app.get("/", response_class=HTMLResponse)
async def test():
    """Return the contents of ``html/index.html`` (path relative to the CWD)."""
    page_path = "html/index.html"
    with open(page_path, "r", encoding="utf-8") as page:
        html = page.read()
    return html