ExerciseOCR/main.py

116 lines
3.4 KiB
Python
Raw Normal View History

2024-10-25 12:53:51 +00:00
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, UploadFile, File
from fastapi.responses import StreamingResponse, HTMLResponse
from pydantic import BaseModel
from typing import List
import base64
import asyncio
import openai
app = FastAPI()

# The API key is read from a plain-text file named "key", resolved relative to
# the current working directory. Surrounding whitespace is stripped because
# editors usually append a trailing newline, and a key ending in "\n" yields an
# invalid Authorization header.
with open("key", "r", encoding="utf-8") as f:
    key = f.read().strip()

# OpenAI-compatible client pointed at the open.bigmodel.cn v4 API.
client = openai.OpenAI(api_key=key,
                       base_url="https://open.bigmodel.cn/api/paas/v4/"
                       )
@app.get("/")
async def root():
    """Return a fixed greeting payload for the service root."""
    payload = {"message": "Hello World"}
    return payload
@app.get("/hello/{name}")
async def say_hello(name: str):
    """Greet the caller by the *name* taken from the URL path."""
    greeting = "Hello {}".format(name)
    return {"message": greeting}
# WebSocket connection manager: tracks accepted sockets and fans messages out.
class ConnectionManager:
    """Registry of open WebSocket connections with broadcast helpers."""

    def __init__(self):
        # Sockets accepted by connect() and not yet unregistered.
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the handshake on *websocket* and register it for broadcasts."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        """Unregister *websocket*; a socket that is not registered is ignored.

        Tolerating absence matters because a broadcast may already have
        dropped a dead socket before its endpoint handler gets here.
        """
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)

    async def send_personal_message(self, message: str, websocket: WebSocket):
        """Send a text *message* to a single *websocket*."""
        await websocket.send_text(message)

    async def broadcast(self, message: str):
        """Send a text *message* to every registered connection."""
        await self._fan_out(lambda connection: connection.send_text(message))

    async def broadcast_json(self, data: dict):
        """Send *data* as JSON to every registered connection."""
        await self._fan_out(lambda connection: connection.send_json(data))

    async def _fan_out(self, send):
        """Await ``send(connection)`` for every registered connection.

        Iterates over a snapshot because other handlers may connect or
        disconnect at each ``await``. A connection whose send fails because
        the peer is gone is unregistered instead of aborting delivery to the
        remaining clients (and instead of propagating into the caller).
        """
        for connection in list(self.active_connections):
            try:
                await send(connection)
            except (WebSocketDisconnect, RuntimeError, OSError):
                # NOTE(review): the exact exception raised for a closed peer
                # varies across Starlette / ASGI-server versions; these cover
                # the disconnect, send-after-close, and socket-error cases.
                self.disconnect(connection)


manager = ConnectionManager()
# WebSocket endpoint: registers the client for broadcasts while it stays connected.
@app.websocket("/event")
async def event(websocket: WebSocket):
    """Keep *websocket* registered with ``manager`` until the connection ends.

    Text messages from the client are only printed to stdout for now.
    """
    await manager.connect(websocket)
    try:
        while True:
            # Placeholder for handling client messages; currently just logged.
            data = await websocket.receive_text()
            print(f"Received message from client: {data}")
    except WebSocketDisconnect:
        # Normal client-initiated close; unregistration happens in `finally`.
        pass
    finally:
        # Unregister on every exit path (e.g. receive_text() failing with
        # something other than WebSocketDisconnect) so broadcasts never keep
        # targeting a dead socket.
        manager.disconnect(websocket)
# MIME type assumed when the upload does not declare an image content type.
_DEFAULT_IMAGE_MIME = "image/jpeg"


def _image_data_url(image_base64, content_type):
    """Build a ``data:`` URL for *image_base64*.

    *content_type* (may be None) is used only if it is an ``image/*`` type;
    otherwise the previous hard-coded JPEG type is used.
    """
    if content_type and content_type.startswith("image/"):
        mime = content_type
    else:
        mime = _DEFAULT_IMAGE_MIME
    return f"data:{mime};base64,{image_base64}"


def _create_description_stream(data_url):
    """Start a streaming ``glm-4v`` chat completion describing the image at *data_url*.

    Uses the synchronous module-level ``client``, so this call blocks; run it
    off the event loop.
    """
    return client.chat.completions.create(
        model="glm-4v",
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "请描述该图片",
                    },
                    {
                        "type": "image_url",
                        "image_url": {"url": data_url},
                    },
                ],
            }
        ],
        stream=True,
    )


# POST endpoint: describe an uploaded image, streaming the text back to the
# caller and mirroring the image and each text piece to WebSocket clients.
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Stream a model-generated description of the uploaded image.

    The base64 image is broadcast as ``{"type": "image", ...}`` and every
    non-empty text delta as ``{"type": "text", ...}`` to connected WebSocket
    clients, while the same text is streamed as ``text/plain``.
    """
    # Read the uploaded image and encode it as base64.
    image_data = await file.read()
    image_base64 = base64.b64encode(image_data).decode('utf-8')

    loop = asyncio.get_running_loop()
    # The OpenAI client is synchronous: issue the request in the default
    # thread pool so the event loop keeps serving other clients meanwhile.
    response = await loop.run_in_executor(
        None,
        _create_description_stream,
        _image_data_url(image_base64, file.content_type),
    )

    # Broadcast the image to WebSocket clients.
    await manager.broadcast_json({"type": "image", "content": image_base64})

    async def stream_response():
        """Yield text deltas from the model stream without blocking the loop."""
        exhausted = object()  # sentinel returned by next() at end of stream
        chunks = iter(response)
        try:
            while True:
                # Each next() may block on network I/O; pull it in a worker thread.
                chunk = await loop.run_in_executor(None, next, chunks, exhausted)
                if chunk is exhausted:
                    break
                # Guard against chunks with an empty `choices` list instead of
                # failing with IndexError mid-stream.
                if not chunk.choices:
                    continue
                content = chunk.choices[0].delta.content
                if content:
                    yield content
                    # Mirror the same text to every WebSocket client.
                    await manager.broadcast_json({"type": "text", "content": content})
        finally:
            # Release the upstream connection if the HTTP client disconnected
            # early. NOTE(review): assumes openai>=1 `Stream.close()`; guarded
            # with getattr in case the installed version lacks it.
            close = getattr(response, "close", None)
            if close is not None:
                close()

    return StreamingResponse(stream_response(), media_type="text/plain")
@app.get("/test", response_class=HTMLResponse)
async def test():
    """Serve the contents of the local ``test.html`` file as an HTML page."""
    page_path = "test.html"
    with open(page_path, "r", encoding="utf-8") as page:
        html = page.read()
    return html