ExerciseOCR/main.py

127 lines
3.4 KiB
Python
Raw Permalink Normal View History

2024-10-25 14:53:08 +00:00
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, UploadFile, File
2024-10-26 09:26:52 +00:00
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
from base64 import b64encode
from openai import OpenAI
from os.path import isfile
from sys import platform
from config import Config
from ConnectionManager import ConnectionManager
2024-10-25 12:53:51 +00:00
# ASGI application object; the route decorators further down register handlers on it.
app = FastAPI()
def signal_handler():
    """Reload the configuration file and rebuild the API client.

    Intended to run when the process receives SIGHUP. Rebinds the
    module-level ``config`` and ``client`` so later requests use the
    new settings.
    """
    global config, client
    print("Received SIGHUP, reloading config")
    # Config.load returns the loaded configuration (it is used that way at
    # module start-up), so the result must be assigned; calling it and
    # discarding the return value would leave the old settings in place.
    config = Config.load(config_file)
    client = OpenAI(api_key=config.key, base_url=config.base_url)
# SIGHUP is only registered on non-Windows platforms.
if platform != 'win32':
    import signal

    # The handler receives (signum, frame); neither is needed for a reload.
    signal.signal(
        signal.SIGHUP,
        lambda _signum, _frame: signal_handler(),
    )
def init():
    """Create a default configuration file if none exists at ``config_file``.

    When the file is missing, a default ``Config`` is bound to the
    module-level ``config`` and saved to disk. When the file already
    exists nothing is done.
    """
    global config
    if isfile(config_file):
        return
    config = Config()
    config.save(config_file)
# Path of the configuration file, relative to the working directory.
config_file = "./config.toml"

# Write a default config file on first run, then load the settings from disk.
init()
config = Config.load(config_file)

# API client built from the loaded settings.
client = OpenAI(base_url=config.base_url, api_key=config.key)

# Keeps track of WebSocket clients and broadcasts messages to them.
manager = ConnectionManager()
# WebSocket endpoint
@app.websocket("/listener")
async def event(websocket: WebSocket):
    """Register a WebSocket client and keep the connection open.

    The socket is handed to the connection manager, then text messages
    from the client are read and printed until the client disconnects,
    at which point the socket is removed from the manager.
    """
    await manager.connect(websocket)
    print("Client connected")
    try:
        while True:
            # Logic for handling client messages can be added here.
            incoming = await websocket.receive_text()
            print(f"Received message from client: {incoming}")
    except WebSocketDisconnect:
        manager.disconnect(websocket)
# POST endpoint
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Send an uploaded image to the chat-completions API and stream the reply.

    The image is base64-encoded and sent together with ``config.prompt``
    to the model named by ``config.model``. The image, and then every
    text fragment of the reply, is also broadcast to all connected
    WebSocket clients.

    Args:
        file: The uploaded image.

    Returns:
        A ``text/plain`` streaming response yielding the model's reply
        fragment by fragment.
    """
    # Imported locally so this block does not depend on file-level imports.
    from fastapi.concurrency import iterate_in_threadpool, run_in_threadpool

    # Read the image file and convert it to base64.
    image_data = await file.read()
    image_base64 = b64encode(image_data).decode('utf-8')

    # Label the data URL with the upload's real image type; fall back to
    # JPEG (the previous hard-coded value) when the type is missing or is
    # not an image type.
    mime_type = file.content_type or ""
    if not mime_type.startswith("image/"):
        mime_type = "image/jpeg"

    # Build the API request. The client is synchronous, so the call runs in
    # a worker thread to avoid blocking the event loop.
    response = await run_in_threadpool(
        client.chat.completions.create,
        model=config.model,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": config.prompt,
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:{mime_type};base64,{image_base64}"
                        },
                    },
                ],
            }
        ],
        stream=True,
    )

    # Broadcast the image data.
    await manager.broadcast_json({"type": "image", "content": image_base64})

    # Stream the result back.
    async def stream_response():
        # Pull chunks from the blocking stream in a worker thread so other
        # requests and WebSocket traffic keep flowing between chunks.
        async for chunk in iterate_in_threadpool(iter(response)):
            # Some chunks carry no choices at all; skip them instead of
            # failing on choices[0].
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            if content:
                yield content
                # Also broadcast the fragment to every WebSocket client.
                await manager.broadcast_json({"type": "text", "content": content})

    return StreamingResponse(stream_response(), media_type="text/plain")
# HTML pages
@app.get("/terminal", response_class=HTMLResponse)
async def test():
    """Return the contents of ``html/terminal.html`` as the response body."""
    with open("html/terminal.html", mode="r", encoding="utf-8") as page:
        markup = page.read()
    return markup
@app.get("/upload", response_class=HTMLResponse)
async def test():
    """Return the contents of ``html/upload.html`` as the response body."""
    with open("html/upload.html", mode="r", encoding="utf-8") as page:
        markup = page.read()
    return markup
@app.get("/", response_class=HTMLResponse)
async def test():
    """Return the contents of ``html/index.html`` as the response body."""
    with open("html/index.html", mode="r", encoding="utf-8") as page:
        markup = page.read()
    return markup
2024-10-26 09:26:52 +00:00
@app.get("/favicon.ico", response_class=FileResponse)
async def favicon():
    """Serve the site icon from ``html/favicon.ico``."""
    icon_path = "html/favicon.ico"
    return FileResponse(icon_path)
@app.get("/static/{path:path}", response_class=FileResponse)
async def static(path: str):
    """Serve a file from the ``html/static`` directory.

    Args:
        path: Path of the requested file relative to ``html/static``.

    Returns:
        A ``FileResponse`` for the requested file.

    Raises:
        HTTPException: 404 when the path resolves outside ``html/static``
            or does not name an existing regular file.
    """
    # Imported locally so this block does not depend on file-level imports.
    from pathlib import Path

    from fastapi import HTTPException

    static_root = Path("html/static").resolve()
    # `path` comes straight from the URL. Resolve it so that ".." segments
    # or an absolute path cannot escape the static directory.
    target = (static_root / path).resolve()
    if static_root not in target.parents or not target.is_file():
        raise HTTPException(status_code=404, detail="Not Found")
    return FileResponse(target)