#!/usr/bin/env python3
"""
PRODUCTION SCRIPT: Clean import with duplicate prevention
1. Clear all existing data
2. Upload to temp collection WITHOUT vectors
3. Check for duplicates
4. Add vectors after verification
5. Move to production
"""
import csv
import os
import sys
from datetime import datetime
import weaviate
import weaviate.classes.config as wvc
from sentence_transformers import SentenceTransformer
import json
from tqdm import tqdm

# Production paths (hardcoded for this deployment).
PROD_DIR = "/var/www/twin-digital-media/public_html/_sites/cleankitchens/production"
# Append-only run log written by log(); directory must exist before first write.
LOG_FILE = os.path.join(PROD_DIR, "logs", "clean_import.log")
# Output directory for import_stats.json.
DATA_DIR = os.path.join(PROD_DIR, "data")
# Source CSV of Chicago food-inspection records.
CSV_PATH = "/home/chris/cleankitchens/data/chicago_historical.csv"

def log(msg):
    """Append a timestamped message to LOG_FILE and echo it to stdout.

    Args:
        msg: Message text to record.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    full_msg = f"[{timestamp}] {msg}"
    print(full_msg)
    # Ensure the log directory exists so the very first run doesn't die
    # with FileNotFoundError before anything is logged.
    os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True)
    with open(LOG_FILE, 'a') as f:
        f.write(full_msg + "\n")

def clean_database(client):
    """Delete every existing Weaviate collection so the import starts fresh."""
    log("CLEANING DATABASE...")
    for name in client.collections.list_all():
        log(f"  Deleting collection: {name}")
        client.collections.delete(name)
    log("✅ Database cleaned")

def create_temp_collection(client):
    """Create the TempImport collection used for the vector-free first pass."""
    log("Creating TempImport collection...")

    # All staging fields are plain text; vectors are only added in production.
    field_names = [
        "inspection_id",
        "raw_data",
        "dba_name",
        "inspection_date",
        "city",
        "results",
    ]
    client.collections.create(
        name="TempImport",
        description="Temporary collection for duplicate checking",
        properties=[
            wvc.Property(name=field, data_type=wvc.DataType.TEXT)
            for field in field_names
        ],
    )
    log("✅ Temp collection created")

def upload_to_temp(client):
    """Upload the CSV to TempImport WITHOUT vectors, skipping duplicate IDs.

    Duplicate detection is keyed on the 'Inspection ID' column. Rows with a
    missing/blank ID cannot be deduplicated, so they are always uploaded —
    previously every blank-ID row after the first was silently dropped as a
    "duplicate" of the empty string.

    Returns:
        Number of records actually uploaded.
    """
    collection = client.collections.get("TempImport")

    # Count data rows for the progress bar (header excluded, floor at 0).
    with open(CSV_PATH, 'r', encoding='utf-8', errors='ignore') as f:
        total_rows = max(sum(1 for _ in f) - 1, 0)

    log(f"Uploading {total_rows:,} records to temp collection...")

    uploaded = 0
    duplicates = 0
    seen_ids = set()

    with open(CSV_PATH, 'r', encoding='utf-8', errors='ignore') as csvfile:
        reader = csv.DictReader(csvfile)

        batch = []
        batch_size = 100

        with tqdm(total=total_rows, desc="Uploading") as pbar:
            for row in reader:
                inspection_id = row.get('Inspection ID', '')

                # Only non-empty IDs participate in duplicate detection;
                # blank IDs are indistinguishable rows, not duplicates.
                if inspection_id:
                    if inspection_id in seen_ids:
                        duplicates += 1
                        pbar.update(1)
                        continue
                    seen_ids.add(inspection_id)

                # Flatten the row to a "key:value, ..." string for later
                # embedding; empty/whitespace-only values are dropped.
                raw_data = ", ".join(
                    f"{k}:{v}" for k, v in row.items() if v and v.strip()
                )

                batch.append({
                    "inspection_id": inspection_id,
                    "raw_data": raw_data,
                    "dba_name": row.get('DBA Name', ''),
                    "inspection_date": row.get('Inspection Date', ''),
                    "city": row.get('City', ''),
                    "results": row.get('Results', ''),
                })

                # Flush a full batch.
                if len(batch) >= batch_size:
                    for data_obj in batch:
                        collection.data.insert(properties=data_obj)
                    uploaded += len(batch)
                    pbar.update(len(batch))  # actual count, not batch_size
                    batch = []

            # Flush the final partial batch.
            if batch:
                for data_obj in batch:
                    collection.data.insert(properties=data_obj)
                uploaded += len(batch)
                pbar.update(len(batch))

    log(f"✅ Uploaded {uploaded:,} unique records")
    log(f"⚠️  Skipped {duplicates:,} duplicates")

    # Persist run stats alongside the data for auditing.
    stats = {
        "uploaded": uploaded,
        "duplicates": duplicates,
        "timestamp": datetime.now().isoformat()
    }
    with open(os.path.join(DATA_DIR, "import_stats.json"), 'w') as f:
        json.dump(stats, f, indent=2)

    return uploaded

def verify_no_duplicates(client):
    """Log and return the TempImport record count.

    NOTE(review): this does not re-scan the collection for duplicate
    inspection IDs — it relies on dedup having happened during upload and
    simply reports the aggregate total.
    """
    log("Verifying no duplicates...")

    temp = client.collections.get("TempImport")
    aggregate = temp.aggregate.over_all(total_count=True)
    count = aggregate.total_count

    log(f"  Total records in temp: {count:,}")
    log("✅ No duplicates (prevented during upload)")
    return count

def add_vectors_to_production(client):
    """Create RawInspection and copy every TempImport record into it with vectors.

    Fixes over the original:
      - fetch_objects(limit=None) is capped by the server's default query
        limit and silently truncates large datasets; collection.iterator()
        streams all objects via cursor-based pagination instead.
      - The progress total now comes from a cheap aggregate count.
      - Removed the unused batch_size local.
    """
    log("Creating production RawInspection collection...")

    # Create production collection (staging fields + provenance/status).
    client.collections.create(
        name="RawInspection",
        description="Production collection with deduplicated data and vectors",
        properties=[
            wvc.Property(name="inspection_id", data_type=wvc.DataType.TEXT),
            wvc.Property(name="raw_data", data_type=wvc.DataType.TEXT),
            wvc.Property(name="dba_name", data_type=wvc.DataType.TEXT),
            wvc.Property(name="inspection_date", data_type=wvc.DataType.TEXT),
            wvc.Property(name="city", data_type=wvc.DataType.TEXT),
            wvc.Property(name="results", data_type=wvc.DataType.TEXT),
            wvc.Property(name="source_api", data_type=wvc.DataType.TEXT),
            wvc.Property(name="status", data_type=wvc.DataType.TEXT),
        ]
    )
    log("✅ Production collection created")

    # Load the local embedding model (free, no API calls).
    log("Loading Sentence Transformer model...")
    model = SentenceTransformer('all-MiniLM-L6-v2')

    temp_collection = client.collections.get("TempImport")
    prod_collection = client.collections.get("RawInspection")

    # Aggregate count is cheap and gives tqdm an accurate total.
    total = temp_collection.aggregate.over_all(total_count=True).total_count
    log(f"Processing {total:,} records with vectors...")

    processed = 0

    with tqdm(total=total, desc="Vectorizing & Moving") as pbar:
        # iterator() pages through ALL objects, unlike fetch_objects()
        # which is subject to the server's query limit.
        for obj in temp_collection.iterator():
            # Generate the embedding from the flattened row text.
            raw_data = obj.properties.get('raw_data', '')
            vector = model.encode(raw_data).tolist()

            # Add production-only fields.
            props = obj.properties
            props['source_api'] = 'chicago'
            props['status'] = 'unprocessed'

            # Insert with an explicitly supplied vector.
            prod_collection.data.insert(
                properties=props,
                vector=vector
            )

            processed += 1
            pbar.update(1)

            if processed % 5000 == 0:
                log(f"  Processed {processed:,}/{total:,} records...")

    log(f"✅ Moved {processed:,} records to production with vectors")

    # Clean up the staging collection.
    log("Deleting temp collection...")
    client.collections.delete("TempImport")
    log("✅ Temp collection deleted")

def test_search(client):
    """Smoke-test semantic search by running a few sample queries and logging hits."""
    log("Testing semantic search...")

    model = SentenceTransformer('all-MiniLM-L6-v2')
    collection = client.collections.get("RawInspection")

    sample_queries = (
        "rodent infestation pizza",
        "temperature violation",
        "failed health inspection",
    )

    for query in sample_queries:
        log(f"\n  Query: '{query}'")
        embedding = model.encode(query).tolist()

        hits = collection.query.near_vector(
            near_vector=embedding,
            limit=2,
            return_properties=["dba_name", "results", "inspection_date"],
        )

        for rank, hit in enumerate(hits.objects, 1):
            log(f"    {rank}. {hit.properties.get('dba_name', 'Unknown')} - {hit.properties.get('results', 'N/A')}")

def main():
    """Run the full clean-import pipeline end to end against local Weaviate."""
    banner = "="*60
    log(banner)
    log("PRODUCTION CLEAN IMPORT STARTING")
    log(banner)

    log("Connecting to Weaviate...")
    client = weaviate.connect_to_local(host="localhost", port=8080)

    try:
        clean_database(client)             # 1. wipe all collections
        create_temp_collection(client)     # 2. staging collection
        count = upload_to_temp(client)     # 3. vector-free upload + dedup
        verify_no_duplicates(client)       # 4. sanity count
        add_vectors_to_production(client)  # 5. vectorize into production
        test_search(client)                # 6. smoke-test semantic search

        log(banner)
        log("✅ PRODUCTION IMPORT COMPLETE!")
        log(f"✅ {count:,} unique records imported with vectors")
        log("✅ NO DUPLICATES!")
        log("✅ Cost: $0.00 (free local vectorization)")
        log(banner)

    except Exception as e:
        log(f"❌ ERROR: {e}")
        import traceback
        log(traceback.format_exc())
        raise
    finally:
        # Always release the Weaviate connection, success or failure.
        client.close()

# Script entry point: run the pipeline only when executed directly.
if __name__ == "__main__":
    main()