用FastAPI和Streamlit实现一个ChatBot

前言

本文使用FastAPI+Streamlit实现一个流式响应类ChatGPT的LLM应用,这里只是一个demo,后续会基于此实现一个完整的MCP Client + MCP Server的MCP应用。

Streamlit是专为机器学习和数据科学项目打造的开源Python库,它允许开发者快速创建美观的交互式Web应用,而无需前端开发经验。通过简单的Python脚本,就能构建出功能丰富的数据应用界面。而且官方文档就有ChatBot的示例,直接拿过来稍微修改下就能使用了,上手起来非常简单。

之所以不直接在Streamlit实现MCP Client,是因为MCP SDK的方法几乎都是异步方法,而Streamlit仅支持同步方法,因此我们需要一个中间层来集成MCP Client,而这个中间层我用的就是FastAPI。通过FastAPI实现流式响应,然后在Streamlit再进行流式处理。

用Streamlit实现一个简单的ChatBot

本文大部分基于Streamit的chatbot示例,没做什么改动,只是先尝试下如何使用streamlit。其中pkg.cfg模块提供一些配置信息,后面也都会用到,逻辑比较简单,所以这里就不贴上了,最后的补充部分再贴上。

import streamlit as st from langchain_openai import ChatOpenAI  from pkg.config import cfg   with st.sidebar:     st.button("Clear Chat", on_click=lambda: st.session_state.pop("messages", None), width="stretch")  st.title("MCP Chatbot") st.caption("🚀 A Streamlit chatbot powered by Qwen")  llm = ChatOpenAI(     base_url=cfg.llm_base_url,     model=cfg.llm_model,     api_key=cfg.llm_api_key,     temperature=0.3, )  # Initialize chat history if "messages" not in st.session_state:     st.session_state["messages"] = []  # Display chat messages from history on app rerun for msg in st.session_state.messages:     st.chat_message(msg["role"]).markdown(msg["content"])  # React to user input if prompt := st.chat_input(placeholder="What's up?"):     st.session_state.messages.append({"role": "user", "content": prompt})     # st.chat_message("user").write(prompt)     with st.chat_message("user"):         st.markdown(prompt)      def steam_llm():         for chunk in llm.stream(input=st.session_state.messages):             yield chunk.content      with st.chat_message("assistant"):         msg = st.write_stream(steam_llm())     st.session_state.messages.append({"role": "assistant", "content": msg}) 

在上述代码中,我们创建了一个基础的聊天机器人应用:

  1. 使用Streamlit的侧边栏添加了清除聊天记录的按钮
  2. 初始化了Qwen大语言模型
  3. 实现了聊天历史记录的存储和显示
  4. 通过llm.stream方法实现了流式响应,在用户界面上逐字显示AI回复

FastAPI实现流式响应和客户端流式处理

FastAPI的StreamingResponse实现了流式响应,requests、httpx、aiohttp等http客户端模块都支持流式处理,这里以httpx为例

服务端

from http import HTTPStatus from typing import Sequence  import uvicorn from fastapi import FastAPI from fastapi.responses import (JSONResponse, PlainTextResponse,                                StreamingResponse) from langchain_core.messages import AIMessage, HumanMessage from langchain_openai import ChatOpenAI from pydantic import BaseModel  from pkg.config import cfg from pkg.log import logger  app = FastAPI()  class Message(BaseModel):     role: str     content: str  class UserAsk(BaseModel):     thread_id: str     messages: Sequence[Message]  llm = ChatOpenAI(     model=cfg.llm_model,     api_key=cfg.llm_api_key,     base_url=cfg.llm_base_url,     temperature=0.3, )  async def generate_response(messages: Sequence[Message]):     """一个异步生成器,用于实时生成文本"""     msgs = []     for m in messages:         if m.role in ("human", "user"):             msgs.append(HumanMessage(content=m.content))         elif m.role in ("ai", "assistant"):             msgs.append(AIMessage(content=m.content))         else:             print(f"Unknown role: {m.role}")     async for chunk in llm.astream(msgs):         # Ensure only string is yielded         if hasattr(chunk, "content"):             yield str(chunk.content)         else:             yield str(chunk)          @app.get("/health") async def health():     return PlainTextResponse(content="ok", status_code=HTTPStatus.OK)  @app.post("/stream") async def post_ask_stream(user_ask: UserAsk):     logger.info(f"user_ask: {user_ask}")     if not user_ask.messages:         return JSONResponse(content={"error": "query is empty"}, status_code=HTTPStatus.BAD_REQUEST)     generator = generate_response(user_ask.messages)     return StreamingResponse(generator, media_type="text/event-stream")   if __name__ == "__main__":     uvicorn.run(app, host="127.0.0.1", port=8000) 

在服务端代码中:

  1. 我们定义了两个数据模型:Message,用于处理聊天消息的格式
  2. 使用ChatOpenAI初始化语言模型
  3. generate_response函数是一个异步生成器,将聊天历史转换为LangChain消息格式并流式生成回复
  4. /stream端点接收用户消息并返回流式响应
  5. /health端点用于健康检查

客户端流式处理示例

import asyncio  import httpx   def test_sync_stream():     """同步方式测试流式响应"""     print("=== 同步方式测试流式响应 ===")     with httpx.stream("POST", "http://127.0.0.1:8000/stream",                       json={                          "thread_id": "test_thread_1",                          "messages":[{"role": "user", "content": "你好,请介绍一下你自己"}],                      }) as response:         print("响应状态码:", response.status_code)         for chunk in response.iter_text():             if chunk:                 print(chunk, end='', flush=True)     print("n" + "="*50 + "n")  async def test_async_stream():     """异步方式测试流式响应"""     print("=== 异步方式测试流式响应 ===")     async with httpx.AsyncClient() as client:         async with client.stream("POST", "http://127.0.0.1:8000/stream",                                 json={                                     "thread_id": "test_thread_2",                                     "messages": [{"role": "user", "content": "写一首关于夏天雨天的现代诗"}]                                 }) as response:             print("响应状态码:", response.status_code)             async for chunk in response.aiter_text():                 if chunk:                     print(chunk, end='', flush=True)     print("n" + "="*50 + "n")  def test_health_endpoint():     """测试健康检查端点"""     print("=== 测试健康检查端点 ===")     response = httpx.get("http://127.0.0.1:8000/health")     print("健康检查响应:", response.text)     print("状态码:", response.status_code)     print("="*50 + "n")  if __name__ == "__main__":     test_health_endpoint()          test_sync_stream()          asyncio.run(test_async_stream()) 

客户端示例展示了如何使用httpx处理流式响应:

  1. test_sync_stream函数演示了同步方式处理流式响应
  2. test_async_stream函数演示了异步方式处理流式响应
  3. 两种方式都使用了httpx的流式API逐块处理响应内容

Streamlit集成

import httpx import streamlit as st  def stream_llm(messages: list = []):     with httpx.stream("POST", "http://127.0.0.1:8000/stream", json={"thread_id": "test_thread_1", "messages": messages}) as resp:         for chunk in resp.iter_text():             if chunk:                 yield chunk   with st.sidebar:     st.button("Clear Chat", on_click=lambda: st.session_state.pop("messages", None), width="stretch")  st.title("MCP Chatbot") st.caption("🚀 A Streamlit chatbot powered by Qwen")   # Initialize chat history if "messages" not in st.session_state:     st.session_state["messages"] = []  # Display chat messages from history on app rerun for msg in st.session_state.messages:     st.chat_message(msg["role"]).markdown(msg["content"])  # React to user input if prompt := st.chat_input(placeholder="What's up?"):     st.session_state.messages.append({"role": "user", "content": prompt})     with st.chat_message("user"):         st.markdown(prompt)      with st.chat_message("assistant"):         msg = st.write_stream(stream_llm(st.session_state.messages))     st.session_state.messages.append({"role": "assistant", "content": msg}) 

集成代码中,我们创建了一个函数stream_llm来通过httpx连接FastAPI后端,并将流式响应传递给Streamlit前端显示。这样就实现了从前端到后端的完整流式处理链路。

补充

  • pkg/config.py
import json from pathlib import Path  class Config:     def __init__(self):         p = Path(__file__).parent.parent / "conf" / "config.json"         if not p.exists():             raise FileNotFoundError(f"Config file not found: {p}")         self.data = self.read_json(str(p))      def read_json(self, filepath: str) -> dict:         with open(filepath, "r") as f:             return json.load(f)              @property     def llm_model(self) -> str:         return self.data["llm"]["model"]          @property     def llm_api_key(self):         return self.data["llm"]["api_key"]          @property     def llm_base_url(self) -> str:         return self.data["llm"]["base_url"]          @property     def server_host(self) -> str:         return self.data["server"]["host"]          @property     def server_port(self) -> int:         return self.data["server"]["port"]      cfg = Config() 
  • pkg/log.py
import logging import sys  def set_formatter():     """设置formatter"""     fmt = "%(asctime)s | %(name)s | %(levelname)s | %(filename)s:%(lineno)d | %(funcName)s | %(message)s"     datefmt = "%Y-%m-%d %H:%M:%S"     return logging.Formatter(fmt, datefmt=datefmt)   def set_stream_handler():     return logging.StreamHandler(sys.stdout)  def set_file_handler():     return logging.FileHandler("app.log", mode="a", encoding="utf-8")   def get_logger(name: str = "mylogger", level=logging.DEBUG):     logger = logging.getLogger(name)      formatter = set_formatter()     # handler = set_stream_handler()     handler = set_file_handler()     handler.setFormatter(formatter)     logger.addHandler(handler)      logger.setLevel(level)      return logger   logger = get_logger() 

发表评论

评论已关闭。

相关文章