136 lines
3.6 KiB
Python
136 lines
3.6 KiB
Python
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, UploadFile, File
|
|
from fastapi.responses import StreamingResponse, HTMLResponse
|
|
from typing import List
|
|
import base64
|
|
import openai
|
|
import signal
|
|
|
|
# The FastAPI application; every route defined below is registered on it.
app = FastAPI()

# Placeholders filled in by load_config(): the API key and the prompt text.
key = prompt = ""
|
|
|
|
|
|
def get_key(path="key"):
    """Read the API key from *path* (default: the ``key`` file in the working directory).

    Surrounding whitespace (typically the trailing newline an editor adds) is
    stripped. Otherwise it would become part of the key string handed to
    ``openai.OpenAI``.

    Args:
        path: file containing the key; defaults to ``"key"``.

    Returns:
        The key as a string with leading/trailing whitespace removed.

    Raises:
        OSError: if the file cannot be opened or read.
    """
    # Explicit UTF-8 so the result does not depend on the platform's locale,
    # matching how get_prompt() reads its file.
    with open(path, "r", encoding="utf-8") as f:
        return f.read().strip()
|
|
|
|
|
|
def get_prompt():
    """Return the complete text of the ``prompt`` file, decoded as UTF-8, unmodified."""
    prompt_file = open("prompt", "r", encoding="utf-8")
    with prompt_file:
        text = prompt_file.read()
    return text
|
|
|
|
|
|
def load_config(*_signal_args):
    """(Re)load the API key and the prompt from disk into the module globals.

    Any positional arguments are accepted and ignored. This lets the function
    also be passed straight to ``signal.signal``, which calls handlers as
    ``handler(signum, frame)``.

    Both files are read before either global is assigned. If reading one of
    them raises, the previous key/prompt pair is left untouched instead of
    being half-updated.

    Raises:
        OSError: propagated from get_key() / get_prompt() if a file cannot be read.
    """
    global key, prompt
    new_key = get_key()
    new_prompt = get_prompt()
    key, prompt = new_key, new_prompt
|
|
|
|
|
|
# Base URL of the OpenAI-compatible endpoint used by the client.
_API_BASE_URL = "https://open.bigmodel.cn/api/paas/v4/"


def _reload_config(signum, frame):
    """SIGHUP handler: re-read key/prompt from disk and rebuild the client.

    The client has to be recreated because it captured the old key when it
    was constructed. Read errors are reported and swallowed: an exception
    raised inside a signal handler would surface in whatever code the main
    thread happened to be running.
    """
    global client
    try:
        load_config()
    except OSError as err:
        print(f"Config reload failed, keeping previous settings: {err}")
        return
    client = openai.OpenAI(api_key=key, base_url=_API_BASE_URL)


load_config()
client = openai.OpenAI(api_key=key, base_url=_API_BASE_URL)

# The original code passed load_config() -- i.e. its return value None -- as
# the handler, which signal.signal rejects. Register a real callable, and only
# after the initial load. SIGHUP does not exist on Windows, hence the guard.
if hasattr(signal, "SIGHUP"):
    signal.signal(signal.SIGHUP, _reload_config)
|
|
|
|
|
|
# WebSocket connection manager.
class ConnectionManager:
    """Keep track of accepted WebSocket connections and send messages to them."""

    def __init__(self):
        # Connections accepted by connect() and not yet removed by disconnect().
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the WebSocket handshake and register the connection."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        """Unregister *websocket*; a no-op if it is not (or no longer) registered.

        Unknown sockets must be tolerated because broadcast_json() may already
        have dropped a connection whose send failed.
        """
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)

    @staticmethod
    async def send_personal_message(message: str, websocket: WebSocket):
        """Send *message* as a text frame to a single connection."""
        await websocket.send_text(message)

    async def broadcast_json(self, data: dict):
        """Send *data* as JSON to every registered connection.

        Iterates over a snapshot of the list, because other tasks can connect
        or disconnect while this coroutine is suspended at ``await``. A
        connection whose send raises is dropped, rather than letting one dead
        socket abort the whole broadcast (and with it the caller's request).
        """
        for connection in list(self.active_connections):
            try:
                await connection.send_json(data)
            except Exception as err:  # any send failure means the socket is unusable
                print(f"Dropping WebSocket after failed send: {err!r}")
                self.disconnect(connection)


manager = ConnectionManager()
|
|
|
|
|
|
# WebSocket endpoint: listeners connect here to receive what /predict broadcasts.
@app.websocket("/listener")
async def event(websocket: WebSocket):
    """Register *websocket* as a listener and keep it open until it goes away.

    Messages sent by the client are currently only logged.
    """
    await manager.connect(websocket)
    print("Client connected")
    try:
        while True:
            # Client messages are not acted on yet; handling logic would go here.
            data = await websocket.receive_text()
            print(f"Received message from client: {data}")
    except WebSocketDisconnect:
        # The client went away; this is the expected way for the loop to end.
        pass
    finally:
        # Unregister on every exit path (not only WebSocketDisconnect), so the
        # manager does not keep a dead socket in its broadcast list.
        manager.disconnect(websocket)
|
|
|
|
|
|
# POST endpoint: analyse an uploaded image with the model and stream the answer.
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Run the vision model on the uploaded image and stream its reply.

    The image (base64) and every text fragment of the reply are also
    broadcast to all clients connected to ``/listener``.

    Args:
        file: the uploaded image.

    Returns:
        A ``text/plain`` StreamingResponse yielding the reply fragments.
    """
    import asyncio

    loop = asyncio.get_running_loop()

    # Read the uploaded image and base64-encode it for a data URL.
    image_data = await file.read()
    image_base64 = base64.b64encode(image_data).decode('utf-8')

    # Label the data URL with the upload's declared image type (PNG, WebP, ...),
    # falling back to the previously hard-coded JPEG when no image/* type was sent.
    content_type = file.content_type or ""
    mime_type = content_type if content_type.startswith("image/") else "image/jpeg"

    def create_completion():
        # Runs in a worker thread: the OpenAI client call and get_prompt()'s
        # file read are both blocking and would otherwise stall the event loop.
        return client.chat.completions.create(
            model="glm-4v",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": get_prompt(),
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:{mime_type};base64,{image_base64}"
                            },
                        },
                    ],
                }
            ],
            stream=True,
        )

    response = await loop.run_in_executor(None, create_completion)

    # Broadcast the image to the WebSocket listeners.
    await manager.broadcast_json({"type": "image", "content": image_base64})

    end_of_stream = object()

    async def stream_response():
        """Yield reply fragments and broadcast each one to the listeners."""
        chunks = iter(response)
        try:
            while True:
                # next() on the synchronous stream may wait on the network,
                # so it also runs in a worker thread.
                chunk = await loop.run_in_executor(None, next, chunks, end_of_stream)
                if chunk is end_of_stream:
                    break
                # Guard against chunks with an empty `choices` list; indexing
                # [0] on them would raise IndexError.
                if not chunk.choices:
                    continue
                content = chunk.choices[0].delta.content
                if content:
                    yield content
                    # Also send the fragment to every WebSocket client.
                    await manager.broadcast_json({"type": "text", "content": content})
        finally:
            # Release the upstream connection even if the HTTP client stops
            # reading before the stream is exhausted.
            close = getattr(response, "close", None)
            if callable(close):
                close()

    return StreamingResponse(stream_response(), media_type="text/plain")
|
|
|
|
|
|
def _read_html(filename):
    """Return the text of ``html/<filename>`` (relative to the working directory), decoded as UTF-8."""
    with open(f"html/{filename}", "r", encoding="utf-8") as f:
        return f.read()


# The page handlers below are plain `def` (not `async def`) so that FastAPI
# runs their blocking file reads in its threadpool instead of on the event loop.
# NOTE(review): all three handlers are named `test`, as in the original; only
# the last definition stays bound to the module name. Consider distinct names.


@app.get("/terminal", response_class=HTMLResponse)
def test():
    """Serve the terminal page (html/terminal.html)."""
    return _read_html("terminal.html")


@app.get("/upload", response_class=HTMLResponse)
def test():
    """Serve the upload page (html/upload.html)."""
    return _read_html("upload.html")


@app.get("/", response_class=HTMLResponse)
def test():
    """Serve the index page (html/index.html)."""
    return _read_html("index.html")
|