| | import random |
| | import time |
| | import logging |
| | from json import JSONDecodeError |
| |
|
| | import streamlit as st |
| |
|
| | from app_utils.backend_utils import load_statements, query |
| | from app_utils.frontend_utils import ( |
| | set_state_if_absent, |
| | reset_results, |
| | entailment_html_messages, |
| | create_df_for_relevant_snippets, |
| | create_ternary_plot, |
| | build_sidebar, |
| | ) |
| | from app_utils.config import RETRIEVER_TOP_K |
| |
|
| |
|
def _pick_random_statement(statements, current):
    """Return a random entry of *statements* that differs from *current*.

    Entries equal to *current* are filtered out first, so the pick is uniform
    over the remaining entries (the same distribution as rejection sampling).
    When nothing different is available (an empty list, or every entry equals
    *current*), *current* is returned unchanged. Rejection sampling would loop
    forever in that case, or raise IndexError on an empty list.
    """
    candidates = [candidate for candidate in statements if candidate != current]
    if not candidates:
        return current
    return random.choice(candidates)


def _request_rerun():
    """Ask Streamlit to restart the script run from the top.

    Prefers the public rerun API (``st.rerun``, or ``st.experimental_rerun``
    on older releases). Only when neither exists does it fall back to raising
    the script runner's internal ``RerunException``. The module path of that
    exception differs between Streamlit versions, hence the ``hasattr`` probe.
    """
    public_rerun = getattr(st, "rerun", None) or getattr(st, "experimental_rerun", None)
    if public_rerun is not None:
        public_rerun()
        return
    if hasattr(st, "scriptrunner"):
        raise st.scriptrunner.script_runner.RerunException(
            st.scriptrunner.script_requests.RerunData(widget_states=None)
        )
    raise st.runtime.scriptrunner.script_runner.RerunException(
        st.runtime.scriptrunner.script_requests.RerunData(widget_states=None)
    )


def _run_query(statement):
    """Query the backend for *statement* and store the response in session state.

    Clears previous results and records *statement* as the current one before
    querying. Returns True on success. Returns False after showing an error
    to the user (and logging it) when the request fails.
    """
    time_start = time.perf_counter()
    reset_results()
    st.session_state.statement = statement
    with st.spinner("🧠 Running Model..."):
        try:
            st.session_state.results = query(statement, RETRIEVER_TOP_K)
        except JSONDecodeError:
            logging.exception("Could not decode the query response")
            st.error(
                "👓 An error occurred reading the results. Is the document store working?"
            )
            return False
        except Exception:
            # Top-level UI boundary: log the full traceback, show a generic message.
            logging.exception("Query request failed")
            st.error("🐞 An error occurred during the request.")
            return False
    elapsed = time.perf_counter() - time_start
    print(f"S: {statement}")
    print(time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()))
    print(f"elapsed time: {elapsed}")
    return True


def _render_results(results):
    """Render the verdict, the aggregate entailment chart, and the supporting snippets.

    *results* must provide the ``"documents"`` and
    ``"aggregate_entailment_info"`` keys. The latter is a mapping from
    entailment label to score.
    """
    docs = results["documents"]
    agg_entailment_info = results["aggregate_entailment_info"]

    # Headline: the message for the label with the highest aggregate score.
    # NOTE(review): raises ValueError if the mapping is ever empty. The backend
    # is assumed to always return scores; confirm.
    max_key = max(agg_entailment_info, key=agg_entailment_info.get)
    message = entailment_html_messages[max_key]
    # The HTML comes from the project's own message table, presumably trusted.
    st.markdown(f"<br/><h4>{message}</h4>", unsafe_allow_html=True)

    st.markdown("###### Aggregate entailment information:")
    plot_col, table_col = st.columns([2, 1])
    fig = create_ternary_plot(agg_entailment_info)
    with plot_col:
        st.plotly_chart(fig, use_container_width=True)
    with table_col:
        st.write(agg_entailment_info)

    st.markdown("###### Most Relevant snippets:")
    df, urls = create_df_for_relevant_snippets(docs)
    st.dataframe(df)
    links = "".join(f"[{doc}]({url}) " for doc, url in urls.items())
    st.markdown("Data: " + links)


def main():
    """Render the Referral Mis-Sell page and run the model on demand.

    Streamlit re-executes this function from the top on every interaction.
    Each pass:

    1. initialises session-state defaults;
    2. draws the statement box and the "Run" / "Random statement" buttons;
    3. queries the backend when the user asks for it;
    4. renders any results stored in session state.
    """
    statements = load_statements()
    build_sidebar()

    # Session-state defaults. By its name, set_state_if_absent only fills in
    # keys that are not already present, so values survive reruns.
    set_state_if_absent("statement", "Referral bonus can only be given if your friend joins Newton School on your behalf")
    set_state_if_absent("answer", "")
    set_state_if_absent("results", None)
    set_state_if_absent("raw_json", None)
    set_state_if_absent("random_statement_requested", False)

    st.write("Referral Mis-Sell")
    st.write()
    st.markdown(
        """
##### Enter statement
"""
    )

    # Editing the text clears any stale results via the on_change callback.
    statement = st.text_input(
        "", value=st.session_state.statement, max_chars=100, on_change=reset_results
    )
    col1, col2 = st.columns(2)
    for col in (col1, col2):
        col.markdown(
            "<style>.stButton button {width:100%;}</style>", unsafe_allow_html=True
        )

    run_pressed = col1.button("Run")

    if col2.button("Random statement"):
        reset_results()
        st.session_state.statement = _pick_random_statement(
            statements, st.session_state.statement
        )
        st.session_state.random_statement_requested = True
        # The text box was already drawn above this button in this pass, so
        # re-run the script to show the new statement as its value.
        _request_rerun()
    else:
        st.session_state.random_statement_requested = False

    # Query when "Run" is pressed, or when the typed text differs from the
    # last stored statement (the user edited it and submitted).
    run_query = (
        run_pressed or statement != st.session_state.statement
    ) and not st.session_state.random_statement_requested

    if run_query and statement:
        if not _run_query(statement):
            return

    if st.session_state.results:
        _render_results(st.session_state.results)
| |
|
| |
|
# Entry point. `streamlit run` executes this script as `__main__`, so the
# guard still starts the app there while keeping plain imports side-effect free.
if __name__ == "__main__":
    main()
| |
|