ExerciseOCR/main.py

117 lines
3.2 KiB
Python
Raw Normal View History

2024-10-25 14:53:08 +00:00
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, UploadFile, File
2024-10-25 12:53:51 +00:00
from fastapi.responses import StreamingResponse, HTMLResponse
from base64 import b64encode
from openai import OpenAI
from os.path import isfile
from sys import platform
from config import Config
from ConnectionManager import ConnectionManager
2024-10-25 12:53:51 +00:00
# ASGI application object; the WebSocket, POST and GET handlers below register on it.
app = FastAPI()
def signal_handler():
    """Reload the configuration from disk and rebuild the API client.

    Called when the process receives SIGHUP. Rebinds the module-level
    ``config`` and ``client`` so later requests use the reloaded settings.
    """
    print("Received SIGHUP, reloading config")
    global config, client
    # Config.load is called on the class at start-up and its return value is
    # bound to ``config``; the result must be bound here as well, otherwise
    # the reloaded settings are discarded and the client below is rebuilt
    # from the stale configuration.
    config = Config.load(config_file)
    client = OpenAI(api_key=config.key, base_url=config.base_url)
# Install the SIGHUP reload hook everywhere except on win32.
if platform != 'win32':
    import signal

    def _reload_on_sighup(signum, frame):
        """Adapt the two-argument signal callback to signal_handler()."""
        signal_handler()

    signal.signal(signal.SIGHUP, _reload_on_sighup)
def init():
    """Write a default configuration file if ``config_file`` does not exist.

    When the file is missing, a default ``Config`` is created, bound to the
    module-level ``config`` and saved to ``config_file``. When the file is
    already present nothing happens.
    """
    global config
    if isfile(config_file):
        return
    config = Config()
    config.save(config_file)
# Location of the configuration file, relative to the working directory.
config_file = "./config.toml"

# Make sure a configuration file exists, then load it.
init()
config = Config.load(config_file)

# API client built from the loaded key and base URL.
client = OpenAI(
    api_key=config.key,
    base_url=config.base_url,
)

# Shared manager the handlers use to register sockets and broadcast messages.
manager = ConnectionManager()
# WebSocket endpoint
@app.websocket("/listener")
async def event(websocket: WebSocket):
    """Register a listener socket and keep it open until the client leaves.

    The socket is handed to ``manager.connect`` and then read in a loop;
    every text message received is only printed. When the loop ends, for any
    reason, the socket is removed from the manager again.

    Args:
        websocket: The incoming WebSocket connection.
    """
    await manager.connect(websocket)
    print("Client connected")
    try:
        while True:
            # Hook for handling messages sent by the client.
            data = await websocket.receive_text()
            print(f"Received message from client: {data}")
    except WebSocketDisconnect:
        # Normal client disconnect: nothing to report, cleanup happens below.
        pass
    finally:
        # Always unregister, not only on WebSocketDisconnect, so a socket that
        # failed with another error is not left in the manager.
        manager.disconnect(websocket)
# POST endpoint
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Send an uploaded image to the configured model and stream its reply.

    The upload is base64-encoded and sent, together with ``config.prompt``,
    to the chat-completions API in streaming mode. The encoded image and
    every streamed text fragment are also broadcast through ``manager``.

    Args:
        file: The uploaded image.

    Returns:
        A ``text/plain`` StreamingResponse yielding the text fragments as
        they arrive.
    """
    import asyncio  # local import: only this handler needs it

    # Read the image file and encode it as base64.
    image_data = await file.read()
    image_base64 = b64encode(image_data).decode('utf-8')

    # Build the API request. ``client`` is synchronous, so the call runs in a
    # worker thread to keep the event loop free while waiting.
    # NOTE(review): the data URL always declares image/jpeg, whatever the
    # upload's actual content type is -- confirm this is intended.
    response = await asyncio.to_thread(
        client.chat.completions.create,
        model=config.model,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": config.prompt,
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{image_base64}"
                        },
                    },
                ],
            }
        ],
        stream=True,
    )

    # Broadcast the image data.
    await manager.broadcast_json({"type": "image", "content": image_base64})

    # Stream the result back.
    async def stream_response():
        chunks = iter(response)
        exhausted = object()  # sentinel: next() returns it instead of raising
        while True:
            # Each next() blocks on the network; wait for it off the loop.
            chunk = await asyncio.to_thread(next, chunks, exhausted)
            if chunk is exhausted:
                break
            # Skip chunks that carry no choices instead of raising IndexError.
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            if content:
                yield content
                # Also broadcast the fragment to all connected WebSocket clients.
                await manager.broadcast_json({"type": "text", "content": content})

    return StreamingResponse(stream_response(), media_type="text/plain")
# HTML pages
@app.get("/terminal", response_class=HTMLResponse)
async def test():
    """Return the contents of ``html/terminal.html`` as the response body."""
    with open("html/terminal.html", "r", encoding="utf-8") as page:
        markup = page.read()
    return markup
@app.get("/upload", response_class=HTMLResponse)
async def test():
    """Return the contents of ``html/upload.html`` as the response body."""
    with open("html/upload.html", "r", encoding="utf-8") as page:
        markup = page.read()
    return markup
@app.get("/", response_class=HTMLResponse)
async def test():
    """Return the contents of ``html/index.html`` as the response body."""
    with open("html/index.html", "r", encoding="utf-8") as page:
        markup = page.read()
    return markup