"""FastAPI service that sends uploaded images to a chat-completions API.

The streamed answer is returned to the HTTP caller and, together with the
uploaded image, broadcast to every WebSocket client connected to
``/listener``.
"""
from base64 import b64encode
from os.path import isfile
from sys import platform

from fastapi import FastAPI, File, UploadFile, WebSocket, WebSocketDisconnect
from fastapi.concurrency import iterate_in_threadpool, run_in_threadpool
from fastapi.responses import FileResponse, HTMLResponse, StreamingResponse
from openai import OpenAI

from config import Config
from ConnectionManager import ConnectionManager

config_file = "./config.toml"

app = FastAPI()


def init() -> None:
    """Create and save a default config when ``config_file`` does not exist."""
    global config
    if not isfile(config_file):
        config = Config()
        config.save(config_file)


def signal_handler() -> None:
    """Reload the configuration and rebuild the API client (SIGHUP handler).

    The new config and client are built first and swapped in together, so a
    broken config file leaves the running service on its previous settings
    instead of raising out of the signal handler.
    """
    global config, client
    print("Received SIGHUP, reloading config")
    try:
        # Config.load is used as a factory at module level; its return value
        # must be kept, otherwise the reload has no effect on `config`.
        new_config = Config.load(config_file)
        new_client = OpenAI(api_key=new_config.key, base_url=new_config.base_url)
    except Exception as exc:  # top-level boundary: never crash on reload
        print(f"Config reload failed, keeping previous config: {exc}")
        return
    config, client = new_config, new_client


init()
config = Config.load(config_file)
client = OpenAI(api_key=config.key, base_url=config.base_url)
manager = ConnectionManager()

# Register the reload hook only after config/client exist, so an early SIGHUP
# cannot hit undefined globals. SIGHUP is not available on Windows.
if platform != 'win32':
    import signal

    signal.signal(signal.SIGHUP, lambda signum, frame: signal_handler())


# WebSocket endpoint
@app.websocket("/listener")
async def event(websocket: WebSocket):
    """Register a listener and keep the socket open until it disconnects."""
    await manager.connect(websocket)
    print("Client connected")
    try:
        while True:
            # Hook for handling messages sent by the client; currently they
            # are only logged.
            data = await websocket.receive_text()
            print(f"Received message from client: {data}")
    except WebSocketDisconnect:
        manager.disconnect(websocket)


# POST endpoint
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Send the uploaded image to the model and stream back its answer.

    The image (base64) and every streamed text fragment are also broadcast
    to all connected WebSocket clients.

    Args:
        file: The uploaded image.

    Returns:
        A ``text/plain`` streaming response with the model output.
    """
    # Read the image and encode it as base64.
    image_data = await file.read()
    image_base64 = b64encode(image_data).decode('utf-8')

    # Use the upload's own MIME type when it is an image type; otherwise keep
    # the historical JPEG default.
    mime_type = file.content_type or ""
    if not mime_type.startswith("image/"):
        mime_type = "image/jpeg"

    # Build the API request. The OpenAI client is synchronous, so the call is
    # made in a worker thread to keep the event loop responsive.
    response = await run_in_threadpool(
        client.chat.completions.create,
        model=config.model,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": config.prompt,
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:{mime_type};base64,{image_base64}"
                        },
                    },
                ],
            }
        ],
        stream=True,
    )

    # Broadcast the image data.
    await manager.broadcast_json({"type": "image", "content": image_base64})

    # Stream the result back.
    async def stream_response():
        # Each blocking next() on the stream runs in the thread pool.
        async for chunk in iterate_in_threadpool(iter(response)):
            if not chunk.choices:
                # Some chunks carry no choices; nothing to emit.
                continue
            content = chunk.choices[0].delta.content
            if content:
                yield content
                # Mirror the fragment to all connected WebSocket clients.
                await manager.broadcast_json({"type": "text", "content": content})

    return StreamingResponse(stream_response(), media_type="text/plain")


def _read_html(name: str) -> str:
    """Return the contents of ``html/<name>`` decoded as UTF-8."""
    with open(f"html/{name}", "r", encoding="utf-8") as f:
        return f.read()


# HTML pages
@app.get("/terminal", response_class=HTMLResponse)
async def terminal_page():
    """Serve the terminal page."""
    return _read_html("terminal.html")


@app.get("/upload", response_class=HTMLResponse)
async def upload_page():
    """Serve the upload page."""
    return _read_html("upload.html")


@app.get("/", response_class=HTMLResponse)
async def index_page():
    """Serve the index page."""
    return _read_html("index.html")


@app.get("/favicon.ico", response_class=FileResponse)
async def favicon():
    """Serve the site icon."""
    return FileResponse("html/favicon.ico")