ExerciseOCR/main.py
2024-10-25 20:53:51 +08:00

116 lines
3.4 KiB
Python

from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, UploadFile, File
from fastapi.responses import StreamingResponse, HTMLResponse
from pydantic import BaseModel
from typing import List
import base64
import asyncio
import openai
app = FastAPI()

# Load the API credential from the local "key" file (path is relative to the
# process working directory). Surrounding whitespace is stripped because a
# trailing newline left by an editor would otherwise become part of the key.
with open("key", "r", encoding="utf-8") as f:
    key = f.read().strip()

# OpenAI-compatible client pointed at the base_url below instead of the
# library's default endpoint.
client = openai.OpenAI(
    api_key=key,
    base_url="https://open.bigmodel.cn/api/paas/v4/",
)
@app.get("/")
async def root():
    """Return the fixed greeting payload for the index route."""
    payload = {"message": "Hello World"}
    return payload
@app.get("/hello/{name}")
async def say_hello(name: str):
    """Return a greeting for ``name``, which is taken from the URL path."""
    greeting = f"Hello {name}"
    return {"message": greeting}
# WebSocket connection manager
class ConnectionManager:
    """Track accepted WebSocket connections and fan messages out to them."""

    def __init__(self):
        # Sockets that have been accepted and not yet removed.
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the handshake and start tracking ``websocket``."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        """Stop tracking ``websocket``.

        Safe to call more than once: removing a socket that is no longer
        tracked is a no-op instead of raising ``ValueError``.
        """
        try:
            self.active_connections.remove(websocket)
        except ValueError:
            pass

    async def send_personal_message(self, message: str, websocket: WebSocket):
        """Send a text ``message`` to a single ``websocket``."""
        await websocket.send_text(message)

    async def broadcast(self, message: str):
        """Send a text ``message`` to every tracked connection."""
        await self._send_to_all(lambda connection: connection.send_text(message))

    async def broadcast_json(self, data: dict):
        """Send ``data`` as JSON to every tracked connection."""
        await self._send_to_all(lambda connection: connection.send_json(data))

    async def _send_to_all(self, send):
        """Await ``send(connection)`` for each tracked connection.

        ``send`` is a callable taking one connection and returning an
        awaitable.
        """
        # Iterate over a snapshot: each ``await`` yields to the event loop, so
        # ``disconnect`` may run mid-broadcast, and mutating the list being
        # iterated would silently skip a peer.
        for connection in list(self.active_connections):
            try:
                await send(connection)
            except (WebSocketDisconnect, RuntimeError):
                # The peer is gone or the socket is already closed; drop it
                # and keep delivering to the remaining connections.
                self.disconnect(connection)
# Module-level manager instance shared by the /event WebSocket endpoint and
# the /predict handler's broadcasts.
manager = ConnectionManager()
# WebSocket endpoint
@app.websocket("/event")
async def event(websocket: WebSocket):
    """Register a WebSocket client and keep the connection open.

    The socket is added to ``manager`` so it receives broadcasts. Text
    messages sent by the client are read and printed, nothing more.
    """
    await manager.connect(websocket)
    try:
        while True:
            # Logic for handling client messages can be added here.
            data = await websocket.receive_text()
            print(f"Received message from client: {data}")
    except WebSocketDisconnect:
        # Normal client close; the socket is unregistered in ``finally``.
        pass
    finally:
        # Unregister on every exit path (disconnect, unexpected error, or
        # cancellation) so a dead socket is never left in the broadcast list.
        manager.disconnect(websocket)
# POST endpoint
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Describe an uploaded image with the vision model and stream the text.

    The upload is base64-encoded and sent to the chat-completions API with
    ``stream=True``. The encoded image, then each generated text fragment, is
    also broadcast as JSON to every client tracked by ``manager``.

    Returns:
        A ``text/plain`` ``StreamingResponse`` yielding the generated text.

    Raises:
        HTTPException: 400 if the upload is empty; 502 if the request to the
            model API fails before streaming starts.
    """
    # Read the image file and convert it to base64.
    image_data = await file.read()
    if not image_data:
        raise HTTPException(status_code=400, detail="Uploaded file is empty")
    image_base64 = base64.b64encode(image_data).decode("utf-8")

    # Build the API request. The data URL always declares image/jpeg,
    # whatever content type the upload carried.
    request_kwargs = {
        "model": "glm-4v",
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "请描述该图片",
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{image_base64}"
                        },
                    },
                ],
            }
        ],
        "stream": True,
    }

    # ``client`` is the synchronous OpenAI client, so run its blocking network
    # call in the default executor to keep the event loop (other requests,
    # WebSocket traffic) responsive.
    loop = asyncio.get_running_loop()
    try:
        response = await loop.run_in_executor(
            None, lambda: client.chat.completions.create(**request_kwargs)
        )
    except openai.OpenAIError as exc:
        raise HTTPException(
            status_code=502, detail="Upstream model request failed"
        ) from exc

    # Broadcast the image data.
    await manager.broadcast_json({"type": "image", "content": image_base64})

    # Stream the result back to the caller.
    async def stream_response():
        # Pulling the next chunk blocks on network I/O, so each ``next`` call
        # also runs in the executor. A sentinel default marks exhaustion
        # because StopIteration cannot be propagated through a Future.
        exhausted = object()
        chunks = iter(response)
        while True:
            chunk = await loop.run_in_executor(None, next, chunks, exhausted)
            if chunk is exhausted:
                break
            if not chunk.choices:
                # Skip chunks with no choices instead of raising IndexError.
                continue
            content = chunk.choices[0].delta.content
            if content:
                yield content
                # Also broadcast the fragment to all WebSocket clients.
                await manager.broadcast_json({"type": "text", "content": content})

    return StreamingResponse(stream_response(), media_type="text/plain")
@app.get("/test", response_class=HTMLResponse)
async def test():
    """Serve the contents of ``test.html`` as the HTML response body.

    The file is opened by relative path, so it is resolved against the
    process working directory.
    """
    with open("test.html", "r", encoding="utf-8") as page:
        html = page.read()
    return html