aaron0eidt committed on
Commit
221a3f0
·
1 Parent(s): 7f9ac5c

Add cached analysis results with LFS

Browse files
.gitattributes CHANGED
@@ -1 +1,2 @@
1
  *.npz filter=lfs diff=lfs merge=lfs -text
 
 
1
  *.npz filter=lfs diff=lfs merge=lfs -text
2
+ cache/*.json filter=lfs diff=lfs merge=lfs -text
cache/cached_attribution_results.json CHANGED
The diff for this file is too large to render. See raw diff
 
cache/cached_function_vector_results.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:332bae308b3bc0f92d10bc2389f170c88563080c0bc68de469cef8300d8e9bd7
3
+ size 10673942
function_vectors/function_vectors_page.py CHANGED
@@ -520,23 +520,157 @@ class LayerEvolutionAnalyzer:
520
  'layer_changes': layer_changes
521
  }
522
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
523
  def run_interactive_analysis(input_text, include_attribution=True, include_evolution=True, enable_ai_explanation=True):
524
  # A wrapper function for running the analysis from the UI.
525
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
526
  # Before running, check if models exist if not using a cached value.
527
- # This check relies on the fact that caching is attempted first.
528
  model_path = "./models/OLMo-2-1124-7B"
529
  model_exists = os.path.exists(model_path)
530
- # if not os.path.exists(model_path):
531
- # # We assume if the model path is missing, we are in a static environment.
532
- # # The calling function should have already checked the cache.
533
- # st.info("This live demo is running in a static environment. Only the pre-cached example prompts are available. Please select an example to view its analysis.")
534
- # return None
535
 
536
  current_lang = st.session_state.get('lang', 'en')
537
 
538
  try:
539
  results = _perform_analysis(input_text, include_attribution, include_evolution, current_lang, enable_ai_explanation)
 
 
 
540
  except Exception as e:
541
  if not model_exists:
542
  st.info("This live demo is running in a static environment. Only the pre-cached example prompts are available. Please select an example to view its analysis.")
@@ -1289,6 +1423,9 @@ def display_analysis_results(results, input_text):
1289
  with st.spinner(tr('running_faithfulness_check_spinner')):
1290
  claims = _cached_extract_fv_claims(api_config, st.session_state.explanation_part_1, "pca")
1291
  verification_results = verify_fv_claims(claims, results, "pca")
 
 
 
1292
  else:
1293
  verification_results = []
1294
  st.warning(tr('api_key_not_configured_warning'))
@@ -1365,6 +1502,9 @@ def display_analysis_results(results, input_text):
1365
  with st.spinner(tr('running_faithfulness_check_spinner')):
1366
  claims = _cached_extract_fv_claims(api_config, st.session_state.explanation_part_2, "pca")
1367
  verification_results = verify_fv_claims(claims, results, "pca")
 
 
 
1368
  else:
1369
  verification_results = []
1370
  st.warning(tr('api_key_not_configured_warning'))
@@ -1549,6 +1689,9 @@ def display_analysis_results(results, input_text):
1549
  with st.spinner(tr('running_faithfulness_check_spinner')):
1550
  claims = _cached_extract_fv_claims(api_config, st.session_state.explanation_part_3, "pca")
1551
  verification_results = verify_fv_claims(claims, results, "pca")
 
 
 
1552
  else:
1553
  verification_results = []
1554
  st.warning(tr('api_key_not_configured_warning'))
@@ -1713,6 +1856,9 @@ def display_evolution_results(evolution_results):
1713
  with st.spinner(tr('running_faithfulness_check_spinner')):
1714
  claims = _cached_extract_fv_claims(api_config, st.session_state.evolution_explanation_part_1, "evolution")
1715
  verification_results = verify_fv_claims(claims, st.session_state.analysis_results, "evolution")
 
 
 
1716
  else:
1717
  verification_results = []
1718
  st.warning(tr('api_key_not_configured_warning'))
@@ -1821,6 +1967,9 @@ def display_evolution_results(evolution_results):
1821
  with st.spinner(tr('running_faithfulness_check_spinner')):
1822
  claims = _cached_extract_fv_claims(api_config, st.session_state.evolution_explanation_part_2, "evolution")
1823
  verification_results = verify_fv_claims(claims, st.session_state.analysis_results, "evolution")
 
 
 
1824
  else:
1825
  verification_results = []
1826
  st.warning(tr('api_key_not_configured_warning'))
 
520
  'layer_changes': layer_changes
521
  }
522
 
523
def update_fv_cache(input_text, results):
    """Store the analysis ``results`` for ``input_text`` in the on-disk JSON cache.

    The cache lives at ``cache/cached_function_vector_results.json`` and maps
    each input text to a dict with the keys ``attribution``, ``evolution``,
    ``pca_explanation``, ``evolution_explanation`` and ``faithfulness``.
    An existing entry for the same ``input_text`` is overwritten; entries for
    other inputs are preserved.

    Caching is best-effort: if the cache cannot be written, a message is
    printed and the function returns normally, so a caching problem never
    masks an analysis that already succeeded.

    Args:
        input_text: The prompt that was analysed; used as the cache key.
        results: Mapping of analysis results. Only the keys listed above are
            read; numpy arrays and scalars inside them are converted to plain
            Python values before being written.
    """
    cache_file = os.path.join("cache", "cached_function_vector_results.json")

    def make_serializable(obj):
        """Recursively convert numpy arrays/scalars into JSON-compatible values."""
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            # np.generic is the base class of every numpy scalar, so this
            # covers all widths (int8, uint16, float32, bool_, ...).
            return obj.item()
        if isinstance(obj, dict):
            return {
                (k.item() if isinstance(k, np.generic) else k): make_serializable(v)
                for k, v in obj.items()
            }
        if isinstance(obj, (list, tuple)):
            return [make_serializable(v) for v in obj]
        return obj

    # Load whatever is already cached. A missing, unreadable or malformed
    # file (or one whose root is not an object) means we start from scratch.
    cached_data = {}
    try:
        with open(cache_file, "r", encoding="utf-8") as f:
            loaded = json.load(f)
    except FileNotFoundError:
        pass
    except (OSError, ValueError) as e:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
        print(f"Could not read FV cache, starting fresh: {e}")
    else:
        if isinstance(loaded, dict):
            cached_data = loaded

    serializable_data = {
        'attribution': {},
        'evolution': make_serializable(results.get('evolution')),
        'pca_explanation': results.get('pca_explanation'),
        'evolution_explanation': results.get('evolution_explanation'),
        'faithfulness': make_serializable(results.get('faithfulness') or {})
    }

    attr = results.get('attribution')
    if attr is not None:
        # Only the per-input values are stored; the remaining attribution
        # fields are not written to the cache.
        serializable_data['attribution'] = {
            'input_activation': make_serializable(attr.get('input_activation')),
            'function_type_scores': make_serializable(attr.get('function_type_scores')),
            'category_scores': make_serializable(attr.get('category_scores')),
            'input_text': attr.get('input_text')
        }

    cached_data[input_text] = serializable_data

    # Write to a temporary file and swap it in, so that a failure while
    # dumping cannot leave a truncated cache file behind.
    tmp_file = cache_file + ".tmp"
    try:
        os.makedirs("cache", exist_ok=True)
        with open(tmp_file, "w", encoding="utf-8") as f:
            json.dump(cached_data, f, ensure_ascii=False, indent=4)
        os.replace(tmp_file, cache_file)
    except (OSError, TypeError, ValueError) as e:
        print(f"Failed to save FV analysis to cache: {e}")
        try:
            os.remove(tmp_file)
        except OSError:
            pass
        return

    print(f"Saved FV analysis for '{input_text}' to cache.")
574
+
575
def update_fv_cache_with_faithfulness(input_text, key, verification_results):
    """Attach faithfulness verification results to an existing cache entry.

    Reads ``cache/cached_function_vector_results.json`` and, if it contains an
    entry for ``input_text``, stores ``verification_results`` under
    ``entry["faithfulness"][key]`` and writes the file back. Nothing is
    written when the cache file or the entry does not exist.

    The update is best-effort: any failure is printed and swallowed so that
    it never interrupts the caller.

    Args:
        input_text: Cache key of the entry to update.
        key: Name under which the results are stored inside ``faithfulness``
            (the visible call sites pass ``"pca"`` or ``"evolution"``).
        verification_results: Results to store; numpy arrays and scalars are
            converted to plain Python values first.
    """
    cache_file = os.path.join("cache", "cached_function_vector_results.json")
    if not os.path.exists(cache_file):
        return

    def make_serializable(obj):
        """Recursively convert numpy arrays/scalars into JSON-compatible values."""
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            # np.generic is the base class of every numpy scalar, so this
            # covers all widths (int8, uint16, float32, bool_, ...).
            return obj.item()
        if isinstance(obj, dict):
            return {
                (k.item() if isinstance(k, np.generic) else k): make_serializable(v)
                for k, v in obj.items()
            }
        if isinstance(obj, (list, tuple)):
            return [make_serializable(v) for v in obj]
        return obj

    tmp_file = cache_file + ".tmp"
    try:
        with open(cache_file, "r", encoding="utf-8") as f:
            cached_data = json.load(f)

        entry = cached_data.get(input_text) if isinstance(cached_data, dict) else None
        if not isinstance(entry, dict):
            # No usable entry for this input: nothing to update.
            return

        faithfulness = entry.get("faithfulness")
        if not isinstance(faithfulness, dict):
            # Also covers an entry whose "faithfulness" was stored as null.
            faithfulness = {}
            entry["faithfulness"] = faithfulness
        faithfulness[key] = make_serializable(verification_results)

        # Write to a temporary file and swap it in, so that a failure while
        # dumping cannot leave a truncated cache file behind.
        with open(tmp_file, "w", encoding="utf-8") as f:
            json.dump(cached_data, f, ensure_ascii=False, indent=4)
        os.replace(tmp_file, cache_file)
        print(f"Saved faithfulness for {key} to cache.")
    except Exception as e:
        print(f"Failed to update FV cache with faithfulness: {e}")
        try:
            os.remove(tmp_file)
        except OSError:
            pass
610
+
611
  def run_interactive_analysis(input_text, include_attribution=True, include_evolution=True, enable_ai_explanation=True):
612
  # A wrapper function for running the analysis from the UI.
613
 
614
+ # Check cache first
615
+ cache_file = os.path.join("cache", "cached_function_vector_results.json")
616
+ if os.path.exists(cache_file):
617
+ try:
618
+ with open(cache_file, "r", encoding="utf-8") as f:
619
+ cached_data = json.load(f)
620
+ if input_text in cached_data:
621
+ print(f"Loading FV analysis for '{input_text}' from cache.")
622
+ data = cached_data[input_text]
623
+
624
+ results = {
625
+ 'evolution': data.get('evolution'),
626
+ 'pca_explanation': data.get('pca_explanation'),
627
+ 'evolution_explanation': data.get('evolution_explanation'),
628
+ 'faithfulness': data.get('faithfulness')
629
+ }
630
+
631
+ if 'attribution' in data:
632
+ attr_data = data['attribution']
633
+ input_activation = np.array(attr_data['input_activation'])
634
+
635
+ # Load static vectors
636
+ current_lang = st.session_state.get('lang', 'en')
637
+ ft_vectors, cat_vectors, error = _load_precomputed_vectors(current_lang)
638
+
639
+ if not error:
640
+ results['attribution'] = {
641
+ 'input_activation': input_activation,
642
+ 'function_type_scores': attr_data.get('function_type_scores'),
643
+ 'category_scores': attr_data.get('category_scores'),
644
+ 'function_types_mapping': FUNCTION_TYPES,
645
+ 'input_text': input_text,
646
+ 'category_vectors': cat_vectors,
647
+ 'function_type_vectors': ft_vectors
648
+ }
649
+
650
+ st.session_state.user_input_3d_data = results.get('attribution')
651
+
652
+ # Populate faithfulness in analysis_results if needed
653
+ if 'faithfulness' in results and results['faithfulness']:
654
+ if 'analysis_results' not in st.session_state:
655
+ st.session_state.analysis_results = {}
656
+ st.session_state.analysis_results['pca_faithfulness'] = results['faithfulness'].get('pca')
657
+ st.session_state.analysis_results['evolution_faithfulness'] = results['faithfulness'].get('evolution')
658
+
659
+ return results
660
+ except Exception as e:
661
+ print(f"Error loading from cache: {e}")
662
+
663
  # Before running, check if models exist if not using a cached value.
 
664
  model_path = "./models/OLMo-2-1124-7B"
665
  model_exists = os.path.exists(model_path)
 
 
 
 
 
666
 
667
  current_lang = st.session_state.get('lang', 'en')
668
 
669
  try:
670
  results = _perform_analysis(input_text, include_attribution, include_evolution, current_lang, enable_ai_explanation)
671
+ # Save to cache
672
+ update_fv_cache(input_text, results)
673
+
674
  except Exception as e:
675
  if not model_exists:
676
  st.info("This live demo is running in a static environment. Only the pre-cached example prompts are available. Please select an example to view its analysis.")
 
1423
  with st.spinner(tr('running_faithfulness_check_spinner')):
1424
  claims = _cached_extract_fv_claims(api_config, st.session_state.explanation_part_1, "pca")
1425
  verification_results = verify_fv_claims(claims, results, "pca")
1426
+ # Update cache
1427
+ if 'attribution' in results and 'input_text' in results['attribution']:
1428
+ update_fv_cache_with_faithfulness(results['attribution']['input_text'], "pca", verification_results)
1429
  else:
1430
  verification_results = []
1431
  st.warning(tr('api_key_not_configured_warning'))
 
1502
  with st.spinner(tr('running_faithfulness_check_spinner')):
1503
  claims = _cached_extract_fv_claims(api_config, st.session_state.explanation_part_2, "pca")
1504
  verification_results = verify_fv_claims(claims, results, "pca")
1505
+ # Update cache
1506
+ if 'attribution' in results and 'input_text' in results['attribution']:
1507
+ update_fv_cache_with_faithfulness(results['attribution']['input_text'], "pca", verification_results)
1508
  else:
1509
  verification_results = []
1510
  st.warning(tr('api_key_not_configured_warning'))
 
1689
  with st.spinner(tr('running_faithfulness_check_spinner')):
1690
  claims = _cached_extract_fv_claims(api_config, st.session_state.explanation_part_3, "pca")
1691
  verification_results = verify_fv_claims(claims, results, "pca")
1692
+ # Update cache
1693
+ if 'attribution' in results and 'input_text' in results['attribution']:
1694
+ update_fv_cache_with_faithfulness(results['attribution']['input_text'], "pca", verification_results)
1695
  else:
1696
  verification_results = []
1697
  st.warning(tr('api_key_not_configured_warning'))
 
1856
  with st.spinner(tr('running_faithfulness_check_spinner')):
1857
  claims = _cached_extract_fv_claims(api_config, st.session_state.evolution_explanation_part_1, "evolution")
1858
  verification_results = verify_fv_claims(claims, st.session_state.analysis_results, "evolution")
1859
+ # Update cache
1860
+ if 'attribution' in st.session_state.analysis_results and 'input_text' in st.session_state.analysis_results['attribution']:
1861
+ update_fv_cache_with_faithfulness(st.session_state.analysis_results['attribution']['input_text'], "evolution", verification_results)
1862
  else:
1863
  verification_results = []
1864
  st.warning(tr('api_key_not_configured_warning'))
 
1967
  with st.spinner(tr('running_faithfulness_check_spinner')):
1968
  claims = _cached_extract_fv_claims(api_config, st.session_state.evolution_explanation_part_2, "evolution")
1969
  verification_results = verify_fv_claims(claims, st.session_state.analysis_results, "evolution")
1970
+ # Update cache
1971
+ if 'attribution' in st.session_state.analysis_results and 'input_text' in st.session_state.analysis_results['attribution']:
1972
+ update_fv_cache_with_faithfulness(st.session_state.analysis_results['attribution']['input_text'], "evolution", verification_results)
1973
  else:
1974
  verification_results = []
1975
  st.warning(tr('api_key_not_configured_warning'))