streamlit_ui

A sample user interface, built with Streamlit, that implements a WebSocket client connecting to the bot's WebSocket server.

import base64
import json
import logging
import queue
import sys
import threading
import time
from typing import Optional

import pandas as pd
import plotly
import streamlit as st
import websocket
from audio_recorder_streamlit import audio_recorder
from streamlit.runtime import Runtime
from streamlit.runtime.app_session import AppSession
from streamlit.runtime.scriptrunner import add_script_run_ctx, get_script_run_ctx
from streamlit.web import cli as stcli

from besser.bot.core.file import File
from besser.bot.platforms.payload import Payload, PayloadAction, PayloadEncoder
from besser.bot.platforms.websocket.message import Message
 21
 22# Time interval to check if a streamlit session is still active, in seconds
 23SESSION_MONITORING_INTERVAL = 10
 24
 25
 26def get_streamlit_session() -> AppSession or None:
 27    session_id = get_script_run_ctx().session_id
 28    runtime: Runtime = Runtime.instance()
 29    return next((
 30        s.session
 31        for s in runtime._session_mgr.list_sessions()
 32        if s.session.id == session_id
 33    ), None)
 34
 35
 36def session_monitoring(interval: int):
 37    runtime: Runtime = Runtime.instance()
 38    session = get_streamlit_session()
 39    while True:
 40        time.sleep(interval)
 41        if not runtime.is_active_session(session.id):
 42            runtime.close_session(session.id)
 43            session.session_state['websocket'].close()
 44            break
 45
 46
 47def main():
 48    try:
 49        # We get the websocket host and port from the script arguments
 50        bot_name = sys.argv[1]
 51    except Exception as e:
 52        # If they are not provided, we use default values
 53        bot_name = 'Chatbot Demo'
 54    st.header(bot_name)
 55    st.markdown("[Github](https://github.com/BESSER-PEARL/BESSER-Bot-Framework)")
 56    # User input component. Must be declared before history writing
 57    user_input = st.chat_input("What is up?")
 58
 59    def on_message(ws, payload_str):
 60        # https://github.com/streamlit/streamlit/issues/2838
 61        streamlit_session = get_streamlit_session()
 62        payload: Payload = Payload.decode(payload_str)
 63        if payload.action == PayloadAction.BOT_REPLY_STR.value:
 64            content = payload.message
 65            t = 'str'
 66        elif payload.action == PayloadAction.BOT_REPLY_FILE.value:
 67            content = payload.message
 68            t = 'file'
 69        elif payload.action == PayloadAction.BOT_REPLY_DF.value:
 70            content = pd.read_json(payload.message)
 71            t = 'dataframe'
 72        elif payload.action == PayloadAction.BOT_REPLY_PLOTLY.value:
 73            content = plotly.io.from_json(payload.message)
 74            t = 'plotly'
 75        elif payload.action == PayloadAction.BOT_REPLY_LOCATION.value:
 76            content = {
 77                'latitude': [payload.message['latitude']],
 78                'longitude': [payload.message['longitude']]
 79            }
 80            t = 'location'
 81        elif payload.action == PayloadAction.BOT_REPLY_OPTIONS.value:
 82            t = 'options'
 83            d = json.loads(payload.message)
 84            content = []
 85            for button in d.values():
 86                content.append(button)
 87        message = Message(t, content, is_user=False)
 88        streamlit_session._session_state['queue'].put(message)
 89        streamlit_session._handle_rerun_script_request()
 90
 91    def on_error(ws, error):
 92        pass
 93
 94    def on_open(ws):
 95        pass
 96
 97    def on_close(ws, close_status_code, close_msg):
 98        pass
 99
100    def on_ping(ws, data):
101        pass
102
103    def on_pong(ws, data):
104        pass
105
106    user_type = {
107        0: 'assistant',
108        1: 'user'
109    }
110
111    if 'history' not in st.session_state:
112        st.session_state['history'] = []
113
114    if 'queue' not in st.session_state:
115        st.session_state['queue'] = queue.Queue()
116
117    if 'websocket' not in st.session_state:
118        try:
119            # We get the websocket host and port from the script arguments
120            host = sys.argv[2]
121            port = sys.argv[3]
122        except Exception as e:
123            # If they are not provided, we use default values
124            host = 'localhost'
125            port = '8765'
126        ws = websocket.WebSocketApp(f"ws://{host}:{port}/",
127                                    on_open=on_open,
128                                    on_message=on_message,
129                                    on_error=on_error,
130                                    on_close=on_close,
131                                    on_ping=on_ping,
132                                    on_pong=on_pong)
133        websocket_thread = threading.Thread(target=ws.run_forever)
134        add_script_run_ctx(websocket_thread)
135        websocket_thread.start()
136        st.session_state['websocket'] = ws
137
138    if 'session_monitoring' not in st.session_state:
139        session_monitoring_thread = threading.Thread(target=session_monitoring,
140                                                     kwargs={'interval': SESSION_MONITORING_INTERVAL})
141        add_script_run_ctx(session_monitoring_thread)
142        session_monitoring_thread.start()
143        st.session_state['session_monitoring'] = session_monitoring_thread
144
145    ws = st.session_state['websocket']
146
147    with st.sidebar:
148
149        if reset_button := st.button(label="Reset bot"):
150            st.session_state['history'] = []
151            st.session_state['queue'] = queue.Queue()
152            payload = Payload(action=PayloadAction.RESET)
153            ws.send(json.dumps(payload, cls=PayloadEncoder))
154
155        if voice_bytes := audio_recorder(text=None, pause_threshold=2):
156            if 'last_voice_message' not in st.session_state or st.session_state['last_voice_message'] != voice_bytes:
157                st.session_state['last_voice_message'] = voice_bytes
158                # Encode the audio bytes to a base64 string
159                voice_message = Message(t='audio', content=voice_bytes, is_user=True)
160                st.session_state.history.append(voice_message)
161                voice_base64 = base64.b64encode(voice_bytes).decode('utf-8')
162                payload = Payload(action=PayloadAction.USER_VOICE, message=voice_base64)
163                try:
164                    ws.send(json.dumps(payload, cls=PayloadEncoder))
165                except Exception as e:
166                    st.error('Your message could not be sent. The connection is already closed')
167        if uploaded_file := st.file_uploader("Choose a file", accept_multiple_files=False):
168            if 'last_file' not in st.session_state or st.session_state['last_file'] != uploaded_file:
169                st.session_state['last_file'] = uploaded_file
170                bytes_data = uploaded_file.read()
171                file_object = File(file_base64=base64.b64encode(bytes_data).decode('utf-8'), file_name=uploaded_file.name, file_type=uploaded_file.type)
172                payload = Payload(action=PayloadAction.USER_FILE, message=file_object.get_json_string())
173                file_message = Message(t='file', content=file_object.to_dict(), is_user=True)
174                st.session_state.history.append(file_message)
175                try:
176                    ws.send(json.dumps(payload, cls=PayloadEncoder))
177                except Exception as e:
178                    st.error('Your message could not be sent. The connection is already closed')
179    for message in st.session_state['history']:
180        with st.chat_message(user_type[message.is_user]):
181            if message.type == 'audio':
182                st.audio(message.content, format="audio/wav")
183            elif message.type == 'file':
184                file: File = File.from_dict(message.content)
185                file_name = file.name
186                file_type = file.type
187                file_data = base64.b64decode(file.base64.encode('utf-8'))
188                st.download_button(label='Download ' + file_name, file_name=file_name, data=file_data, mime=file_type,
189                                   key=file_name + str(time.time()))
190            elif message.type == 'location':
191                st.map(message.content)
192            else:
193                st.write(message.content)
194
195    first_message = True
196    while not st.session_state['queue'].empty():
197        message = st.session_state['queue'].get()
198        if hasattr(message, '__len__'):
199            t = len(message.content) / 1000 * 3
200        else:
201            t = 2
202        if t > 3:
203            t = 3
204        elif t < 1 and first_message:
205            t = 1
206        first_message = False
207        if message.type == 'options':
208            st.session_state['buttons'] = message.content
209        elif message.type == 'file':
210            st.session_state['history'].append(message)
211            with st.chat_message('assistant'):
212                with st.spinner(''):
213                    time.sleep(t)
214                file: File = File.from_dict(message.content)
215                file_name = file.name
216                file_type = file.type
217                file_data = base64.b64decode(file.base64.encode('utf-8'))
218                st.download_button(label='Download ' + file_name, file_name=file_name, data=file_data, mime=file_type,
219                                   key=file_name + str(time.time()))
220        elif message.type == 'location':
221            st.session_state['history'].append(message)
222            st.map(message.content)
223        else:
224            st.session_state['history'].append(message)
225            with st.chat_message("assistant"):
226                with st.spinner(''):
227                    time.sleep(t)
228                st.write(message.content)
229
230    if 'buttons' in st.session_state:
231        buttons = st.session_state['buttons']
232        cols = st.columns(1)
233        for i, option in enumerate(buttons):
234            if cols[0].button(option):
235                with st.chat_message("user"):
236                    st.write(option)
237                message = Message(t='str', content=option, is_user=True)
238                st.session_state.history.append(message)
239                payload = Payload(action=PayloadAction.USER_MESSAGE,
240                                  message=option)
241                ws.send(json.dumps(payload, cls=PayloadEncoder))
242                del st.session_state['buttons']
243                break
244
245    if user_input:
246        if 'buttons' in st.session_state:
247            del st.session_state['buttons']
248        with st.chat_message("user"):
249            st.write(user_input)
250        message = Message(t='str', content=user_input, is_user=True)
251        st.session_state.history.append(message)
252        payload = Payload(action=PayloadAction.USER_MESSAGE,
253                          message=user_input)
254        try:
255            ws.send(json.dumps(payload, cls=PayloadEncoder))
256        except Exception as e:
257            st.error('Your message could not be sent. The connection is already closed')
258
259    st.stop()
260
261
262if __name__ == "__main__":
263    if st.runtime.exists():
264        main()
265    else:
266        sys.argv = ["streamlit", "run", sys.argv[0]]
267        sys.exit(stcli.main())