1. Minimize imports
2. Refactor the code structure: split the Config class and the ConnectionManager class into their own modules
3. Introduce toml to manage the configuration file
4. Introduce pydantic's BaseModel to simplify building the Config class
5. Add a signal-handling system; on Linux the service can now respond to systemctl reload
6. Update some comments
This commit is contained in:
parent 97eef02a92
commit 1ed9822f2f
ConnectionManager.py (new file, 29 lines)
@@ -0,0 +1,29 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2024/10/26 3:35 PM
+# @Author  : 河瞬
+# @FileName: ConnectionManager.py
+# @Software: PyCharm
+# @Github  :
+from typing import List
+from fastapi import WebSocket
+
+
+# WebSocket connection manager
+class ConnectionManager:
+    def __init__(self):
+        self.active_connections: List[WebSocket] = []
+
+    async def connect(self, websocket: WebSocket):
+        await websocket.accept()
+        self.active_connections.append(websocket)
+
+    def disconnect(self, websocket: WebSocket):
+        self.active_connections.remove(websocket)
+
+    @staticmethod
+    async def send_personal_message(message: str, websocket: WebSocket):
+        await websocket.send_text(message)
+
+    async def broadcast_json(self, data: dict):
+        for connection in self.active_connections:
+            await connection.send_json(data)
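For context, a minimal sketch of how this manager is typically wired into a FastAPI WebSocket route. The /ws path, handler name, and echo payload below are illustrative, not from this repository (main.py registers its own endpoint named `event`):

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from ConnectionManager import ConnectionManager

app = FastAPI()
manager = ConnectionManager()


@app.websocket("/ws")  # illustrative path, not the repo's actual route
async def ws_endpoint(websocket: WebSocket):
    await manager.connect(websocket)
    try:
        while True:
            # broadcast every received message to all connected clients
            text = await websocket.receive_text()
            await manager.broadcast_json({"message": text})
    except WebSocketDisconnect:
        manager.disconnect(websocket)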
config.py (new file, 24 lines)
@@ -0,0 +1,24 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2024/10/26 3:25 PM
+# @Author  : 河瞬
+# @FileName: config.py
+# @Software: PyCharm
+import toml
+from pydantic import BaseModel
+
+
+class Config(BaseModel):
+    key: str = ""
+    prompt: str = ""
+    model: str = "chatgpt-4o-latest"
+    base_url: str = ""
+
+    def save(self, config_file):
+        with open(config_file, "w", encoding='utf-8') as f:
+            toml.dump(self.model_dump(), f)
+
+    @classmethod
+    def load(cls, config_file):
+        with open(config_file, "r", encoding='utf-8') as f:
+            config_data = toml.load(f)
+        return cls(**config_data)
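For reference, the config.toml that Config().save(...) writes out with the defaults above looks like this; key and base_url are left empty for the operator to fill in:

key = ""
prompt = ""
model = "chatgpt-4o-latest"
base_url = ""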
main.py (75 lines)
@@ -1,60 +1,40 @@
 from fastapi import FastAPI, WebSocket, WebSocketDisconnect, UploadFile, File
 from fastapi.responses import StreamingResponse, HTMLResponse
-from typing import List
-import base64
-import openai
-import signal
+from base64 import b64encode
+from openai import OpenAI
+from os.path import isfile
+from sys import platform
+
+from config import Config
+from ConnectionManager import ConnectionManager
 
 app = FastAPI()
-key = ""
-prompt = ""
 
 
-def get_key():
-    with open("key", "r") as f:
-        k = f.read()
-    return k
+def signal_handler():
+    print("Received SIGHUP, reloading config")
+    global config, client
+    config = Config.load(config_file)
+    client = OpenAI(api_key=config.key, base_url=config.base_url)
 
 
-def get_prompt():
-    with open("prompt", "r", encoding="utf-8") as f:
-        p = f.read()
-    return p
+if platform != 'win32':
+    import signal
+
+    signal.signal(signal.SIGHUP, lambda signum, frame: signal_handler())
 
 
-def load_config():
-    global key, prompt
-    key = get_key()
-    prompt = get_prompt()
+def init():
+    global config
+    if not isfile(config_file):
+        config = Config()
+        config.save(config_file)
 
 
-signal.signal(signal.SIGHUP, load_config())
-
-load_config()
-client = openai.OpenAI(api_key=key,
-                       base_url="https://open.bigmodel.cn/api/paas/v4/")
-
-
-# WebSocket connection manager
-class ConnectionManager:
-    def __init__(self):
-        self.active_connections: List[WebSocket] = []
-
-    async def connect(self, websocket: WebSocket):
-        await websocket.accept()
-        self.active_connections.append(websocket)
-
-    def disconnect(self, websocket: WebSocket):
-        self.active_connections.remove(websocket)
-
-    @staticmethod
-    async def send_personal_message(message: str, websocket: WebSocket):
-        await websocket.send_text(message)
-
-    async def broadcast_json(self, data: dict):
-        for connection in self.active_connections:
-            await connection.send_json(data)
+config_file = "./config.toml"
+init()
+config = Config.load(config_file)
+client = OpenAI(api_key=config.key, base_url=config.base_url)
 
 manager = ConnectionManager()
@@ -78,18 +58,18 @@ async def event(websocket: WebSocket):
 async def predict(file: UploadFile = File(...)):
     # Read the image file and convert it to base64
     image_data = await file.read()
-    image_base64 = base64.b64encode(image_data).decode('utf-8')
+    image_base64 = b64encode(image_data).decode('utf-8')
 
     # Build the request for the API
     response = client.chat.completions.create(
-        model="glm-4v",
+        model=config.model,
         messages=[
             {
                 "role": "user",
                 "content": [
                     {
                         "type": "text",
-                        "text": get_prompt(),
+                        "text": config.prompt,
                     },
                     {
                         "type": "image_url",
@@ -117,6 +97,7 @@ async def predict(file: UploadFile = File(...)):
     return StreamingResponse(stream_response(), media_type="text/plain")
 
 
+# HTML page
 @app.get("/terminal", response_class=HTMLResponse)
 async def test():
     with open("html/terminal.html", "r", encoding="utf-8") as f:
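With the SIGHUP handler registered in main.py, `systemctl reload` works once the service's unit file maps reload to a HUP signal. A minimal sketch of the relevant [Service] lines; the uvicorn invocation, host, and port are assumptions, not taken from this repository:

[Service]
# illustrative start command; adjust to how this app is actually launched
ExecStart=/usr/bin/uvicorn main:app --host 0.0.0.0 --port 8000
# systemctl reload sends SIGHUP to the main process, triggering signal_handler
ExecReload=/bin/kill -HUP $MAINPID

Outside systemd, `kill -HUP <pid>` triggers the same config reload.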
requirements.txt
@@ -2,4 +2,7 @@ websockets~=13.1
 openai~=1.52.2
 fastapi~=0.115.3
 uvicorn~=0.20.0
 python-multipart~=0.0.12
+requests~=2.27.1
+toml~=0.10.2
+pydantic~=2.9.2