#!/usr/bin/env python3
"""
Vectorize all records in RawInspection collection
This script adds vectors to all records that don't have them yet
"""
import weaviate
from sentence_transformers import SentenceTransformer
from datetime import datetime
import sys
import time

def _build_record_text(props):
    """Join the populated descriptive fields of a record into one string.

    Args:
        props: Mapping of property name to value for one record (may be
            empty or None).

    Returns:
        The labelled fields joined with " | ", or "" when none of the
        fields is populated.
    """
    if not props:
        return ""

    labelled_fields = (
        ("Restaurant", "dba_name"),
        ("Type", "facility_type"),
        ("Result", "results"),
        ("Violations", "violations"),
        ("Location", "address"),
        ("City", "city"),
    )
    return " | ".join(
        f"{label}: {props[key]}"
        for label, key in labelled_fields
        if props.get(key)
    )


def _vectorize_batch(collection, model, objects, prior_errors):
    """Encode and store vectors for the objects in one fetched batch.

    Objects that already carry a vector, or that have no text to encode,
    are counted as skipped. All remaining texts are encoded in a single
    ``model.encode`` call and then written back one object at a time so a
    failing update only affects that object.

    Args:
        collection: Collection handle used for ``data.update``.
        model: Model whose ``encode`` accepts a list of strings.
        objects: Objects returned by the batch fetch.
        prior_errors: Errors counted before this batch; only the first
            five errors of the whole run are printed.

    Returns:
        Tuple ``(vectorized, skipped, errors)`` for this batch.
    """
    vectorized = 0
    skipped = 0
    errors = 0

    def record_error(exc, count=1):
        # Print only the first five errors of the run to keep output readable.
        nonlocal errors
        if prior_errors + errors < 5:
            print(f"  Error processing record: {exc}")
        errors += count

    pending = []
    for obj in objects:
        try:
            # The vector is only populated because the fetch asked for it.
            if getattr(obj, "vector", None):
                skipped += 1
                continue
            text = _build_record_text(obj.properties)
        except Exception as e:
            record_error(e)
            continue

        if text:
            pending.append((obj.uuid, text))
        else:
            skipped += 1

    if not pending:
        return vectorized, skipped, errors

    try:
        # One encode call per batch is much cheaper than one per record.
        vectors = model.encode([text for _, text in pending])
    except Exception as e:
        record_error(e, count=len(pending))
        return vectorized, skipped, errors

    for (uuid, _), vector in zip(pending, vectors):
        try:
            collection.data.update(uuid=uuid, vector=vector.tolist())
            vectorized += 1
        except Exception as e:
            record_error(e)

    return vectorized, skipped, errors


def vectorize_all_records():
    """Add vectors to all records in RawInspection that lack one.

    Connects to a local Weaviate instance, walks the whole collection in
    batches, encodes the descriptive fields of every record without a
    vector using the 'all-MiniLM-L6-v2' model, stores the vectors, prints
    progress and a summary, and finally runs a sample vector search.

    The collection is traversed with the ``after`` cursor rather than
    ``offset``, because offset paging stops working once offset + limit
    exceeds the server's maximum result window. Vectors are requested
    explicitly so already-vectorized records can be detected and skipped.

    Returns:
        None. All results are reported on stdout; unexpected failures are
        printed with a traceback instead of being raised.
    """
    print("="*60)
    print("FULL VECTORIZATION PROCESS")
    print("="*60)
    print(f"Started: {datetime.now()}")

    # Connect to Weaviate
    print("\nConnecting to Weaviate...")
    client = weaviate.connect_to_local(host="localhost", port=8080)

    try:
        # Check if collection exists
        collections = client.collections.list_all()
        if "RawInspection" not in collections:
            print("❌ RawInspection collection not found!")
            return

        collection = client.collections.get("RawInspection")

        # Get total count
        response = collection.aggregate.over_all(total_count=True)
        total = response.total_count
        print(f"Total records in collection: {total:,}")

        # Load the model
        print("\nLoading Sentence Transformer model...")
        model = SentenceTransformer('all-MiniLM-L6-v2')
        print("✓ Model loaded")

        # Process in batches
        batch_size = 100
        max_fetch_failures = 3
        processed = 0
        vectorized = 0
        skipped = 0
        errors = 0
        cursor = None  # UUID of the last object seen; None means "from the start"
        fetch_failures = 0

        print(f"\nProcessing {total:,} records in batches of {batch_size}...")
        print("This will take approximately 30-45 minutes for 295k records")
        print("-" * 60)

        start_time = time.time()
        last_update = start_time
        last_reported = 0

        while True:
            try:
                # Fetch batch (cursor paging, vectors included for the skip check)
                batch = collection.query.fetch_objects(
                    limit=batch_size,
                    after=cursor,
                    include_vector=True,
                )
            except Exception as e:
                fetch_failures += 1
                print(f"❌ Batch error at offset {processed}: {e}")
                # A cursor cannot skip past a failed page, so retry the same
                # page a few times and then give up instead of looping forever.
                if fetch_failures >= max_fetch_failures:
                    print(f"❌ Giving up after {fetch_failures} consecutive fetch failures")
                    break
                time.sleep(1)
                continue

            fetch_failures = 0
            if not batch.objects:
                break

            cursor = batch.objects[-1].uuid
            processed += len(batch.objects)

            batch_vectorized, batch_skipped, batch_errors = _vectorize_batch(
                collection, model, batch.objects, errors
            )
            vectorized += batch_vectorized
            skipped += batch_skipped
            errors += batch_errors

            # Show progress every 10 seconds or 1000 records
            current_time = time.time()
            if current_time - last_update > 10 or processed - last_reported >= 1000:
                elapsed = current_time - start_time
                # Rate and ETA are both based on processed records so the
                # estimate stays meaningful when many records are skipped.
                rate = processed / elapsed if elapsed > 0 else 0
                eta = max(total - processed, 0) / rate if rate > 0 else 0

                print(f"Progress: {processed:,}/{total:,} records processed")
                print(f"  Vectorized: {vectorized:,} | Skipped: {skipped:,} | Errors: {errors:,}")
                print(f"  Rate: {rate:.1f} records/sec | ETA: {eta/60:.1f} minutes")
                print("-" * 60)

                last_update = current_time
                last_reported = processed

        # Final summary
        elapsed = time.time() - start_time
        average_rate = vectorized / elapsed if elapsed > 0 else 0
        print("\n" + "="*60)
        print("VECTORIZATION COMPLETE")
        print("="*60)
        print(f"Total time: {elapsed/60:.1f} minutes")
        print(f"Records vectorized: {vectorized:,}")
        print(f"Records skipped: {skipped:,}")
        print(f"Errors: {errors:,}")
        print(f"Average rate: {average_rate:.1f} records/sec")

        # Test search to verify
        print("\n" + "="*60)
        print("TESTING VECTORIZED DATA")
        print("="*60)

        test_query = "restaurant closed rodent violation"
        print(f"Test query: '{test_query}'")
        query_vector = model.encode(test_query).tolist()

        results = collection.query.near_vector(
            near_vector=query_vector,
            limit=5,
            return_properties=["dba_name", "results", "city", "inspection_id"]
        )

        if results.objects:
            print("\n✅ Vector search working! Top 5 results:")
            for i, obj in enumerate(results.objects, 1):
                print(f"  {i}. {obj.properties.get('dba_name', 'Unknown')}")
                print(f"     Result: {obj.properties.get('results', 'N/A')}")
                print(f"     City: {obj.properties.get('city', 'N/A')}")
        else:
            print("❌ No results found - something went wrong!")

    except Exception as e:
        print(f"❌ Fatal error: {e}")
        import traceback
        traceback.print_exc()
    finally:
        # Always release the connection, including on the early return above.
        client.close()

    print(f"\nCompleted: {datetime.now()}")

if __name__ == "__main__":
    # The run is long, so require an explicit "y" before starting;
    # any other answer cancels.
    print("This will vectorize all records in the RawInspection collection.")
    print("Estimated time: 30-45 minutes for 295k records")
    answer = input("\nProceed? (y/n): ")

    confirmed = answer.lower() == 'y'
    if not confirmed:
        print("Cancelled.")
    else:
        vectorize_all_records()