streamlit_ui
A sample user interface, powered by Streamlit, that implements a WebSocket client connecting to the bot's WebSocket server.
import base64
import io
import json
import queue
import sys
import threading
import time
from typing import Optional

import pandas as pd
import plotly
import streamlit as st
import websocket
from audio_recorder_streamlit import audio_recorder
from streamlit.runtime import Runtime
from streamlit.runtime.app_session import AppSession
from streamlit.runtime.scriptrunner import add_script_run_ctx, get_script_run_ctx
from streamlit.web import cli as stcli

from besser.bot.core.file import File
from besser.bot.platforms.payload import Payload, PayloadAction, PayloadEncoder
from besser.bot.platforms.websocket.message import Message
21
# Seconds the session monitoring thread sleeps between checks of whether its
# Streamlit session is still active.
SESSION_MONITORING_INTERVAL: int = 10
24
25
26def get_streamlit_session() -> AppSession or None:
27 session_id = get_script_run_ctx().session_id
28 runtime: Runtime = Runtime.instance()
29 return next((
30 s.session
31 for s in runtime._session_mgr.list_sessions()
32 if s.session.id == session_id
33 ), None)
34
35
36def session_monitoring(interval: int):
37 runtime: Runtime = Runtime.instance()
38 session = get_streamlit_session()
39 while True:
40 time.sleep(interval)
41 if not runtime.is_active_session(session.id):
42 runtime.close_session(session.id)
43 session.session_state['websocket'].close()
44 break
45
46
47def main():
48 try:
49 # We get the websocket host and port from the script arguments
50 bot_name = sys.argv[1]
51 except Exception as e:
52 # If they are not provided, we use default values
53 bot_name = 'Chatbot Demo'
54 st.header(bot_name)
55 st.markdown("[Github](https://github.com/BESSER-PEARL/BESSER-Bot-Framework)")
56 # User input component. Must be declared before history writing
57 user_input = st.chat_input("What is up?")
58
59 def on_message(ws, payload_str):
60 # https://github.com/streamlit/streamlit/issues/2838
61 streamlit_session = get_streamlit_session()
62 payload: Payload = Payload.decode(payload_str)
63 if payload.action == PayloadAction.BOT_REPLY_STR.value:
64 content = payload.message
65 t = 'str'
66 elif payload.action == PayloadAction.BOT_REPLY_FILE.value:
67 content = payload.message
68 t = 'file'
69 elif payload.action == PayloadAction.BOT_REPLY_DF.value:
70 content = pd.read_json(payload.message)
71 t = 'dataframe'
72 elif payload.action == PayloadAction.BOT_REPLY_PLOTLY.value:
73 content = plotly.io.from_json(payload.message)
74 t = 'plotly'
75 elif payload.action == PayloadAction.BOT_REPLY_LOCATION.value:
76 content = {
77 'latitude': [payload.message['latitude']],
78 'longitude': [payload.message['longitude']]
79 }
80 t = 'location'
81 elif payload.action == PayloadAction.BOT_REPLY_OPTIONS.value:
82 t = 'options'
83 d = json.loads(payload.message)
84 content = []
85 for button in d.values():
86 content.append(button)
87 message = Message(t, content, is_user=False)
88 streamlit_session._session_state['queue'].put(message)
89 streamlit_session._handle_rerun_script_request()
90
91 def on_error(ws, error):
92 pass
93
94 def on_open(ws):
95 pass
96
97 def on_close(ws, close_status_code, close_msg):
98 pass
99
100 def on_ping(ws, data):
101 pass
102
103 def on_pong(ws, data):
104 pass
105
106 user_type = {
107 0: 'assistant',
108 1: 'user'
109 }
110
111 if 'history' not in st.session_state:
112 st.session_state['history'] = []
113
114 if 'queue' not in st.session_state:
115 st.session_state['queue'] = queue.Queue()
116
117 if 'websocket' not in st.session_state:
118 try:
119 # We get the websocket host and port from the script arguments
120 host = sys.argv[2]
121 port = sys.argv[3]
122 except Exception as e:
123 # If they are not provided, we use default values
124 host = 'localhost'
125 port = '8765'
126 ws = websocket.WebSocketApp(f"ws://{host}:{port}/",
127 on_open=on_open,
128 on_message=on_message,
129 on_error=on_error,
130 on_close=on_close,
131 on_ping=on_ping,
132 on_pong=on_pong)
133 websocket_thread = threading.Thread(target=ws.run_forever)
134 add_script_run_ctx(websocket_thread)
135 websocket_thread.start()
136 st.session_state['websocket'] = ws
137
138 if 'session_monitoring' not in st.session_state:
139 session_monitoring_thread = threading.Thread(target=session_monitoring,
140 kwargs={'interval': SESSION_MONITORING_INTERVAL})
141 add_script_run_ctx(session_monitoring_thread)
142 session_monitoring_thread.start()
143 st.session_state['session_monitoring'] = session_monitoring_thread
144
145 ws = st.session_state['websocket']
146
147 with st.sidebar:
148
149 if reset_button := st.button(label="Reset bot"):
150 st.session_state['history'] = []
151 st.session_state['queue'] = queue.Queue()
152 payload = Payload(action=PayloadAction.RESET)
153 ws.send(json.dumps(payload, cls=PayloadEncoder))
154
155 if voice_bytes := audio_recorder(text=None, pause_threshold=2):
156 if 'last_voice_message' not in st.session_state or st.session_state['last_voice_message'] != voice_bytes:
157 st.session_state['last_voice_message'] = voice_bytes
158 # Encode the audio bytes to a base64 string
159 voice_message = Message(t='audio', content=voice_bytes, is_user=True)
160 st.session_state.history.append(voice_message)
161 voice_base64 = base64.b64encode(voice_bytes).decode('utf-8')
162 payload = Payload(action=PayloadAction.USER_VOICE, message=voice_base64)
163 try:
164 ws.send(json.dumps(payload, cls=PayloadEncoder))
165 except Exception as e:
166 st.error('Your message could not be sent. The connection is already closed')
167 if uploaded_file := st.file_uploader("Choose a file", accept_multiple_files=False):
168 if 'last_file' not in st.session_state or st.session_state['last_file'] != uploaded_file:
169 st.session_state['last_file'] = uploaded_file
170 bytes_data = uploaded_file.read()
171 file_object = File(file_base64=base64.b64encode(bytes_data).decode('utf-8'), file_name=uploaded_file.name, file_type=uploaded_file.type)
172 payload = Payload(action=PayloadAction.USER_FILE, message=file_object.get_json_string())
173 file_message = Message(t='file', content=file_object.to_dict(), is_user=True)
174 st.session_state.history.append(file_message)
175 try:
176 ws.send(json.dumps(payload, cls=PayloadEncoder))
177 except Exception as e:
178 st.error('Your message could not be sent. The connection is already closed')
179 for message in st.session_state['history']:
180 with st.chat_message(user_type[message.is_user]):
181 if message.type == 'audio':
182 st.audio(message.content, format="audio/wav")
183 elif message.type == 'file':
184 file: File = File.from_dict(message.content)
185 file_name = file.name
186 file_type = file.type
187 file_data = base64.b64decode(file.base64.encode('utf-8'))
188 st.download_button(label='Download ' + file_name, file_name=file_name, data=file_data, mime=file_type,
189 key=file_name + str(time.time()))
190 elif message.type == 'location':
191 st.map(message.content)
192 else:
193 st.write(message.content)
194
195 first_message = True
196 while not st.session_state['queue'].empty():
197 message = st.session_state['queue'].get()
198 if hasattr(message, '__len__'):
199 t = len(message.content) / 1000 * 3
200 else:
201 t = 2
202 if t > 3:
203 t = 3
204 elif t < 1 and first_message:
205 t = 1
206 first_message = False
207 if message.type == 'options':
208 st.session_state['buttons'] = message.content
209 elif message.type == 'file':
210 st.session_state['history'].append(message)
211 with st.chat_message('assistant'):
212 with st.spinner(''):
213 time.sleep(t)
214 file: File = File.from_dict(message.content)
215 file_name = file.name
216 file_type = file.type
217 file_data = base64.b64decode(file.base64.encode('utf-8'))
218 st.download_button(label='Download ' + file_name, file_name=file_name, data=file_data, mime=file_type,
219 key=file_name + str(time.time()))
220 elif message.type == 'location':
221 st.session_state['history'].append(message)
222 st.map(message.content)
223 else:
224 st.session_state['history'].append(message)
225 with st.chat_message("assistant"):
226 with st.spinner(''):
227 time.sleep(t)
228 st.write(message.content)
229
230 if 'buttons' in st.session_state:
231 buttons = st.session_state['buttons']
232 cols = st.columns(1)
233 for i, option in enumerate(buttons):
234 if cols[0].button(option):
235 with st.chat_message("user"):
236 st.write(option)
237 message = Message(t='str', content=option, is_user=True)
238 st.session_state.history.append(message)
239 payload = Payload(action=PayloadAction.USER_MESSAGE,
240 message=option)
241 ws.send(json.dumps(payload, cls=PayloadEncoder))
242 del st.session_state['buttons']
243 break
244
245 if user_input:
246 if 'buttons' in st.session_state:
247 del st.session_state['buttons']
248 with st.chat_message("user"):
249 st.write(user_input)
250 message = Message(t='str', content=user_input, is_user=True)
251 st.session_state.history.append(message)
252 payload = Payload(action=PayloadAction.USER_MESSAGE,
253 message=user_input)
254 try:
255 ws.send(json.dumps(payload, cls=PayloadEncoder))
256 except Exception as e:
257 st.error('Your message could not be sent. The connection is already closed')
258
259 st.stop()
260
261
if __name__ == "__main__":
    if st.runtime.exists():
        # Already running under the Streamlit runtime: render the app
        main()
    else:
        # Started as a plain Python script: relaunch through the Streamlit CLI.
        # Forward the script arguments (bot name, host, port) so that main() can read them.
        sys.argv = ["streamlit", "run", sys.argv[0], *sys.argv[1:]]
        sys.exit(stcli.main())