# coding=utf-8
"""OpenAI-style chat-completions HTTP server backed by a ChatGLM2 model.

Exposes ``POST /v1/chat/completions`` guarded by a bearer token. Requests may
ask for a single JSON response or a server-sent-event stream of chunks.
"""
import argparse
import os
import secrets
import time
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Literal, Optional, Union

import torch
import uvicorn
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse, ServerSentEvent
from starlette.status import HTTP_401_UNAUTHORIZED
from transformers import AutoModel, AutoTokenizer

# Bearer token that clients must present. Configure your own token through the
# CHATGLM_API_TOKEN environment variable; the fallback keeps the value that
# was previously hard-coded so existing deployments continue to work.
API_TOKEN = os.environ.get("CHATGLM_API_TOKEN", "sk-aaabbbcccdddeeefffggghhhiiijjjkkk")

# Populated by the ``__main__`` section below before the server starts.
model = None
tokenizer = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run the app, then release cached GPU memory on shutdown."""
    yield
    # Everything after the yield runs when the application shuts down.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class ChatMessage(BaseModel):
    """A complete chat message with a role and its text."""

    role: Literal["user", "assistant", "system"]
    content: str


class DeltaMessage(BaseModel):
    """An incremental piece of a message, used in streamed chunks."""

    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    """Request body accepted by the chat-completions endpoint."""

    model: str
    messages: List[ChatMessage]
    # NOTE: temperature, top_p and max_length are accepted but are not
    # forwarded to the model by the handlers in this file.
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    max_length: Optional[int] = None
    stream: Optional[bool] = False


class ChatCompletionResponseChoice(BaseModel):
    """One choice of a non-streamed completion."""

    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length"]


class ChatCompletionResponseStreamChoice(BaseModel):
    """One choice of a streamed completion chunk."""

    index: int
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length"]]


class ChatCompletionResponse(BaseModel):
    """Response envelope shared by full completions and streamed chunks."""

    model: str
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))


async def verify_token(request: Request):
    """Validate the ``Authorization: Bearer <token>`` header.

    Returns:
        True when the header carries the configured token.

    Raises:
        HTTPException: 401 when the header is missing, is not a bearer
            credential, or carries the wrong token.
    """
    auth_header = request.headers.get("Authorization")
    if auth_header:
        token_type, _, token = auth_header.partition(" ")
        # Compare as bytes in constant time: compare_digest rejects non-ASCII
        # str input, and a plain == would leak timing information.
        if token_type.lower() == "bearer" and secrets.compare_digest(
            token.encode("utf-8"), API_TOKEN.encode("utf-8")
        ):
            return True
    raise HTTPException(
        status_code=HTTP_401_UNAUTHORIZED,
        detail="Invalid authorization credentials",
    )


def _dump_chunk(chunk: ChatCompletionResponse) -> str:
    """Serialize a chunk to JSON, leaving out fields that were never set."""
    if hasattr(chunk, "model_dump_json"):
        # pydantic v2: .json() no longer accepts json.dumps keyword arguments.
        return chunk.model_dump_json(exclude_unset=True)
    return chunk.json(exclude_unset=True, ensure_ascii=False)


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest, token: bool = Depends(verify_token)):
    """Answer the last user message, either in full or as an event stream.

    A leading system message is prepended to the query text. The remaining
    earlier messages are turned into (user, assistant) history pairs only when
    they come in an even count; otherwise the history is left empty.

    Raises:
        HTTPException: 400 when there are no messages or the final message is
            not from the user.
    """
    # An empty message list would otherwise fail with IndexError (HTTP 500).
    if not request.messages or request.messages[-1].role != "user":
        raise HTTPException(status_code=400, detail="Invalid request")
    query = request.messages[-1].content

    prev_messages = request.messages[:-1]
    if prev_messages and prev_messages[0].role == "system":
        query = prev_messages.pop(0).content + query

    history = []
    if len(prev_messages) % 2 == 0:
        # Walk the earlier messages two at a time; keep well-formed pairs only.
        for user_msg, assistant_msg in zip(prev_messages[::2], prev_messages[1::2]):
            if user_msg.role == "user" and assistant_msg.role == "assistant":
                history.append([user_msg.content, assistant_msg.content])

    if request.stream:
        generate = predict(query, history, request.model)
        return EventSourceResponse(generate, media_type="text/event-stream")

    response, _ = model.chat(tokenizer, query, history=history)
    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role="assistant", content=response),
        finish_reason="stop",
    )

    return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")


async def predict(query: str, history: List[List[str]], model_id: str):
    """Yield JSON-encoded completion chunks for one streamed reply.

    The sequence is: a chunk announcing the assistant role, one chunk per
    newly generated piece of text, a final chunk with ``finish_reason="stop"``,
    and then the literal ``[DONE]`` marker.
    """
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(role="assistant"),
        finish_reason=None,
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
    yield _dump_chunk(chunk)

    current_length = 0

    for new_response, _ in model.stream_chat(tokenizer, query, history):
        # Each item is treated as the full reply so far; emit only the suffix
        # that has not been sent yet.
        if len(new_response) == current_length:
            continue

        new_text = new_response[current_length:]
        current_length = len(new_response)

        choice_data = ChatCompletionResponseStreamChoice(
            index=0,
            delta=DeltaMessage(content=new_text),
            finish_reason=None,
        )
        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
        yield _dump_chunk(chunk)

    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(),
        finish_reason="stop",
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
    yield _dump_chunk(chunk)
    yield '[DONE]'


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", default="16", type=str, help="Model name")
    args = parser.parse_args()

    # Maps the --model_name value to a checkpoint id.
    model_dict = {
        "4": "THUDM/chatglm2-6b-int4",
        "8": "THUDM/chatglm2-6b-int8",
        "16": "THUDM/chatglm2-6b",
    }

    # Unrecognized values fall back to the same checkpoint as "16".
    model_name = model_dict.get(args.model_name, "THUDM/chatglm2-6b")

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_name, trust_remote_code=True).cuda()
    model.eval()

    uvicorn.run(app, host='0.0.0.0', port=6006, workers=1)