| import torch
|
| import torch.nn.functional as F
|
| import os
|
| import sys
|
|
|
|
|
|
|
| sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'src')))
|
|
|
|
|
| from src.tokenizer import generate_v1_data, CharacterTokenizer
|
| from src.model import TinyLLM, n_embed, n_head, n_layer, dropout
|
|
|
|
|
# Use a CUDA device when PyTorch reports one is available; fall back to CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Trained weights, as a path relative to the current working directory.
WEIGHTS_PATH = 'data/tinyllm_v1_weights1.pt'
|
|
|
|
|
@torch.no_grad()
def generate(model, idx, max_new_tokens, temperature=1.0):
    """Autoregressively extend ``idx`` by ``max_new_tokens`` tokens.

    At each step the model sees at most the last ``model.block_size`` tokens.
    The logits of the final position are turned into a distribution, and one
    token is appended to every sequence in the batch.

    Args:
        model: Module exposing a ``block_size`` attribute. Its forward call
            must return a ``(logits, _)`` pair, with logits indexed as
            ``logits[:, -1, :]``, i.e. (batch, time, vocab).
        idx: Tensor of token ids indexed as (batch, time).
        max_new_tokens: Number of tokens to append. Zero or negative values
            return ``idx`` unchanged.
        temperature: Divides the logits before sampling. ``1.0`` (default)
            samples from the model's distribution unchanged. Values below 1
            sharpen the distribution and values above 1 flatten it. ``0``
            deterministically picks the most likely token (greedy decoding).

    Returns:
        Tensor of shape (batch, time + max(max_new_tokens, 0)) holding the
        original ids followed by the generated ones.

    Raises:
        ValueError: If ``temperature`` is negative.

    The model is switched to eval mode for the duration of the call (e.g. to
    disable dropout). Its previous train/eval mode is restored afterwards, so
    callers that are mid-training are not silently left in eval mode.
    """
    if temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")

    block_size = model.block_size  # loop-invariant, read once
    was_training = model.training
    model.eval()
    try:
        for _ in range(max_new_tokens):
            # Crop to the context window the model was built for.
            idx_cond = idx[:, -block_size:]
            logits, _ = model(idx_cond)
            # Only the prediction for the last position is needed.
            logits = logits[:, -1, :]
            if temperature == 0:
                idx_next = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                probs = F.softmax(logits / temperature, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
    finally:
        model.train(was_training)
    return idx
|
|
|
|
|
def setup_inference():
    """Build the tokenizer and model, then load the trained weights.

    The vocabulary is rebuilt from ``generate_v1_data()``. The context length
    is set to the length of the longest string in that data, so both must
    match what the saved weights were trained with.

    Returns:
        ``(model, tokenizer, block_size)`` on success. ``(None, None, None)``
        when the weights file is missing or its contents cannot be loaded into
        the model; an explanatory message is printed first.
    """
    raw_data = generate_v1_data()
    tokenizer = CharacterTokenizer(raw_data)
    # Context window = longest string in the generated data.
    block_size = max(len(s) for s in raw_data)

    model = TinyLLM(
        vocab_size=tokenizer.vocab_size,
        n_embed=n_embed,
        n_head=n_head,
        n_layer=n_layer,
        block_size=block_size,
        dropout=dropout,
    ).to(DEVICE)

    # Only the weight loading is guarded. Errors raised while generating data
    # or constructing the model keep their own traceback instead of being
    # misreported as a missing file or a hyperparameter mismatch.
    try:
        # The file is only ever fed to load_state_dict, so refuse to unpickle
        # arbitrary Python objects from it.
        state_dict = torch.load(WEIGHTS_PATH, map_location=DEVICE, weights_only=True)
        model.load_state_dict(state_dict)
    except FileNotFoundError:
        print(f"Error: Weights file not found at {WEIGHTS_PATH}. Please run train.py first.")
        return None, None, None
    except RuntimeError as e:
        # load_state_dict raises RuntimeError on missing/unexpected keys or
        # tensor size mismatches.
        print(f"Runtime Error during loading: {e}")
        print("Please ensure your src/model.py hyperparameters match the saved weights.")
        return None, None, None

    print(f"\nSuccessfully loaded model weights from {WEIGHTS_PATH}")
    return model, tokenizer, block_size
|
|
|
|
|
def solve_problem(model, tokenizer, question_str, block_size):
    """Answer one question with the model and print the exchange.

    The question is encoded and a single space token is appended as a
    separator. The model is then asked to fill the rest of the context window,
    so the total length is ``block_size``. The full decoded sequence, prompt
    included, is printed. If the prompt already fills the window, an error is
    printed instead. Nothing is returned.
    """
    prompt = tokenizer.encode(question_str)
    prompt.append(tokenizer.encode(' ')[0])  # separator between question and answer

    context = torch.tensor([prompt], dtype=torch.long, device=DEVICE)
    budget = block_size - context.shape[1]

    if budget > 0:
        output_ids = generate(model, context, max_new_tokens=budget)
        answer = tokenizer.decode(output_ids[0].tolist())
        print(f"Question: '{question_str}'")
        print(f"Model Output: '{answer}'")
    else:
        print("Error: Input sequence is too long.")
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Interactive loop: read "N op N" problems and print the model's answer.
    model, tokenizer, block_size = setup_inference()

    if model is not None:
        print("\n--- TinyLLM Math Chatbot Initialized ---")
        print("Enter a single-digit math problem (e.g., 4 + 5, 8 / 2).")
        print("Type 'exit' to quit.")

        # ASCII digits only: str.isdigit() also accepts characters such as
        # '²' or '٣', which the character tokenizer presumably cannot encode.
        digits = frozenset('0123456789')
        operators = frozenset(['+', '-', '*', '/'])

        while True:
            try:
                question_str = input("Input: ")
            except (EOFError, KeyboardInterrupt):
                # Ctrl-D / closed stdin or Ctrl-C: leave the loop cleanly
                # instead of dumping a traceback.
                print()
                break

            # Strip before the exit check so " exit" is also recognised.
            question_str = question_str.strip()
            if question_str.lower() == 'exit':
                break

            parts = question_str.split()
            is_valid = (
                len(parts) == 3
                and parts[0] in digits
                and parts[1] in operators
                and parts[2] in digits
            )
            if not is_valid:
                print("Error: Please enter a problem in the format 'N op N' with single-digit operands (e.g., 2 + 3).\n")
                continue

            # Canonicalise whitespace (tabs, repeated spaces) to the single-space
            # "N op N" form before it reaches the tokenizer.
            question_str = ' '.join(parts)

            solve_problem(model, tokenizer, question_str, block_size)
            print("-" * 30)

        print("\n--- Chatbot Shutting Down ---")