#!/usr/bin/env python3
"""
Create chicago_temp staging table and populate with CSV data
Stage 1: Bulk insert all rows from CSV
Stage 2: Will consolidate violations later
"""

import weaviate
import weaviate.classes.config as wvc
import csv
import os
from datetime import datetime
from tqdm import tqdm

def create_chicago_temp_table(client):
    """Create the chicago_temp staging table.

    Removes the ``ChicagoTemp`` collection if it is already present, then
    recreates it with the full staging schema.

    Args:
        client: Connected Weaviate client exposing ``client.collections``.

    Returns:
        True when the collection was created; False when any step raised
        (the error is printed rather than propagated).
    """
    print("Creating chicago_temp staging table...")

    collection_name = "ChicagoTemp"

    try:
        text = wvc.DataType.TEXT
        # (property name, data type) in schema order
        schema = [
            # Primary fields from CSV
            ("inspection_id", text),
            ("facility_name", text),
            ("aka_name", text),
            ("license_number", text),
            ("facility_type", text),
            ("risk_level", text),
            ("address", text),
            ("city", text),
            ("state", text),
            ("zip_code", text),
            ("inspection_date", text),
            ("inspection_type", text),
            ("results", text),
            ("violations", text),
            ("latitude", text),
            ("longitude", text),
            # Processing fields
            ("raw_data", text),
            ("status", text),
            ("source_city", text),
            ("is_combined", wvc.DataType.BOOL),
            ("combined_from_rows", wvc.DataType.NUMBER),
            # Timestamps
            ("created_at", text),
            ("processed_at", text),
            # Claude response storage
            ("claude_response", text),
            ("error_log", text),
        ]

        # Check existence first so the "Deleted" message is only printed when
        # something was actually removed, and so a genuine delete failure is
        # reported by the handler below instead of being silently swallowed.
        if client.collections.exists(collection_name):
            client.collections.delete(collection_name)
            print("  Deleted existing ChicagoTemp collection")

        # Create new collection with all fields
        client.collections.create(
            name=collection_name,
            description="Staging table for Chicago food inspection data",
            properties=[
                wvc.Property(name=prop_name, data_type=data_type)
                for prop_name, data_type in schema
            ],
        )

        print("✅ chicago_temp table created successfully")
        return True

    except Exception as e:
        print(f"❌ Error creating table: {e}")
        return False

# Default location of the Chicago food-inspection CSV export.
_DEFAULT_CSV_PATH = "/var/www/twin-digital-media/public_html/_sites/cleankitchens/data/chicago/chicago_food_inspections.csv"

# (CSV column header, ChicagoTemp property name), in record order.
_CSV_FIELD_MAP = (
    ('Inspection ID', 'inspection_id'),
    ('DBA Name', 'facility_name'),
    ('AKA Name', 'aka_name'),
    ('License #', 'license_number'),
    ('Facility Type', 'facility_type'),
    ('Risk', 'risk_level'),
    ('Address', 'address'),
    ('City', 'city'),
    ('State', 'state'),
    ('Zip', 'zip_code'),
    ('Inspection Date', 'inspection_date'),
    ('Inspection Type', 'inspection_type'),
    ('Results', 'results'),
    ('Violations', 'violations'),
    ('Latitude', 'latitude'),
    ('Longitude', 'longitude'),
)

# Upper bound on failure messages echoed in the final summary.
_MAX_REPORTED_ERRORS = 5

def _open_csv(csv_path):
    """Open the CSV for reading the way the csv module expects."""
    # newline='' lets the csv module handle line endings inside quoted fields;
    # utf-8-sig drops a leading BOM so the first header name matches.
    return open(csv_path, 'r', encoding='utf-8-sig', errors='ignore', newline='')

def _count_csv_records(csv_path):
    """Return the number of data records (header excluded) in the CSV."""
    with _open_csv(csv_path) as f:
        # Count parsed, non-blank rows rather than physical lines so the
        # total matches what csv.DictReader will yield.
        parsed_rows = sum(1 for fields in csv.reader(f) if fields)
    return max(parsed_rows - 1, 0)

def _note_error(error_samples, message):
    """Remember *message* unless enough samples were already collected."""
    if len(error_samples) < _MAX_REPORTED_ERRORS:
        error_samples.append(message)

def _build_raw_data(row):
    """Flatten the non-blank fields of a CSV row into 'key:value, key:value'."""
    pairs = []
    for key, value in row.items():
        if value and value.strip():
            clean_key = key.replace(' ', '_').replace('#', 'Number')
            # Colon separates key from value, so it must not appear in values
            clean_value = value.replace(':', ';')
            pairs.append(f"{clean_key}:{clean_value}")
    return ", ".join(pairs)

def _build_record(row):
    """Turn one csv.DictReader row into a ChicagoTemp property dict."""
    # DictReader files surplus fields under the None key; such a row does not
    # line up with the header, so reject it with a clear message.
    if None in row:
        raise ValueError("row has more fields than the CSV header")

    raw_data = _build_raw_data(row)

    # `or ''` because DictReader fills columns missing from a short row with None
    record = {prop: row.get(column) or '' for column, prop in _CSV_FIELD_MAP}
    record.update({
        'raw_data': raw_data,
        'status': 'pending',
        'source_city': 'chicago',
        'is_combined': False,
        'combined_from_rows': 0,
        'created_at': datetime.now().isoformat(),
        'processed_at': '',
        'claude_response': '',
        'error_log': '',
    })
    return record

def _insert_batch(collection, batch, error_samples):
    """Insert every record of *batch* one by one; return (inserted, failed)."""
    inserted = 0
    failed = 0
    for data_obj in batch:
        try:
            collection.data.insert(properties=data_obj)
            inserted += 1
        except Exception as e:
            # Best-effort load: a rejected record must not abort the run
            failed += 1
            _note_error(
                error_samples,
                f"insert failed for inspection {data_obj.get('inspection_id')!r}: {e}",
            )
    return inserted, failed

def extract_csv_to_temp(client, csv_path=_DEFAULT_CSV_PATH, batch_size=100):
    """Extract all CSV data into chicago_temp table.

    Every CSV row becomes one object in the ``ChicagoTemp`` collection. Rows
    that cannot be converted and objects that the collection rejects are
    counted as errors and skipped; the first few failure messages are printed
    with the summary.

    Args:
        client: Connected Weaviate client exposing ``client.collections``.
        csv_path: Path of the CSV file to load.
        batch_size: Number of records buffered before they are inserted.

    Returns:
        Number of records inserted successfully.

    Raises:
        OSError: If the CSV file cannot be opened.
    """
    # Count total rows
    print("Counting CSV rows...")
    total_rows = _count_csv_records(csv_path)

    print(f"Found {total_rows:,} records to process")

    # Get collection
    collection = client.collections.get("ChicagoTemp")

    # Process CSV
    print("Extracting CSV data to chicago_temp...")
    inserted = 0
    errors = 0
    error_samples = []
    batch = []

    with _open_csv(csv_path) as csvfile:
        reader = csv.DictReader(csvfile)

        with tqdm(total=total_rows, desc="Inserting") as pbar:
            for row in reader:
                try:
                    batch.append(_build_record(row))
                except Exception as e:
                    errors += 1
                    _note_error(error_samples, f"CSV line {reader.line_num}: {e}")
                    pbar.update(1)
                    continue

                # Insert batch
                if len(batch) >= batch_size:
                    ok, failed = _insert_batch(collection, batch, error_samples)
                    inserted += ok
                    errors += failed
                    pbar.update(len(batch))
                    batch = []

            # Insert remaining records
            if batch:
                ok, failed = _insert_batch(collection, batch, error_samples)
                inserted += ok
                errors += failed
                pbar.update(len(batch))

    print(f"\n✅ Extraction complete:")
    print(f"   Inserted: {inserted:,} records")
    print(f"   Errors: {errors:,} records")
    if error_samples:
        print(f"   First {len(error_samples)} error(s):")
        for message in error_samples:
            print(f"     - {message}")

    return inserted

def verify_data(client):
    """Verify data in chicago_temp table.

    Prints the total object count, three sample records, and how many
    (facility, inspection date) pairs occur more than once among the first
    1000 fetched objects.

    Args:
        client: Connected Weaviate client exposing ``client.collections``.

    Returns:
        The total object count reported by the collection aggregate.
    """
    print("\nVerifying data in chicago_temp...")

    collection = client.collections.get("ChicagoTemp")

    # Get count
    response = collection.aggregate.over_all()
    total_count = response.total_count

    print(f"Total records in chicago_temp: {total_count:,}")

    # Get sample records to verify structure
    print("\nSample records:")
    sample = collection.query.fetch_objects(limit=3)

    for i, obj in enumerate(sample.objects, 1):
        props = obj.properties
        print(f"\nRecord {i}:")
        print(f"  Facility: {props.get('facility_name', 'N/A')}")
        print(f"  Date: {props.get('inspection_date', 'N/A')}")
        print(f"  Result: {props.get('results', 'N/A')}")
        violations = props.get('violations')
        if violations:
            print(f"  Violations: {violations[:100]}...")
        else:
            print("  Violations: None")
        print(f"  Status: {props.get('status', 'N/A')}")

    # Check for facilities with multiple violations on same date
    print("\nChecking for inspections with multiple violations...")
    print("(This will be consolidated in the next step)")

    # Query a few records to analyze
    sample_for_analysis = collection.query.fetch_objects(limit=1000)

    # Keyed by a (facility, date) tuple rather than a '|'-joined string, so a
    # facility name that itself contains '|' cannot break the grouping or the
    # report below.
    facility_date_counts = {}
    for obj in sample_for_analysis.objects:
        props = obj.properties
        key = (props.get('facility_name', ''), props.get('inspection_date', ''))
        facility_date_counts[key] = facility_date_counts.get(key, 0) + 1

    multiples = [(key, count) for key, count in facility_date_counts.items() if count > 1]

    if multiples:
        print(f"Found {len(multiples)} facilities with multiple violation records (sample of 1000)")
        # Show first 3 examples
        for (facility, date), count in multiples[:3]:
            print(f"  - {facility} on {date}: {count} records")
    else:
        print("No facilities with multiple violations found in sample")

    return total_count

def main():
    """Run the staging load: create the table, load the CSV, verify.

    Connects to a local Weaviate instance, stops early when the table cannot
    be created, and always closes the client. Any exception raised by the
    steps is printed with its traceback instead of propagating.
    """
    print("="*60)
    print("CHICAGO DATA EXTRACTION TO TEMP TABLE")
    print("="*60)

    # Connect to Weaviate
    print("Connecting to Weaviate...")
    client = weaviate.connect_to_local(host="localhost", port=8080)

    try:
        # Step 1: Create table
        if not create_chicago_temp_table(client):
            print("Failed to create table, exiting...")
            return

        # Step 2: Extract CSV data
        inserted = extract_csv_to_temp(client)

        # Step 3: Verify
        total = verify_data(client)

        # The table was recreated in step 1, so the stored count is expected
        # to equal the number of successful inserts; flag any discrepancy.
        if total != inserted:
            print(f"⚠️  Count mismatch: inserted {inserted:,} but chicago_temp holds {total:,}")

        print("\n" + "="*60)
        print("✅ EXTRACTION COMPLETE")
        print(f"   Total records in chicago_temp: {total:,}")
        print("   Ready for consolidation step")
        print("="*60)

    except Exception as e:
        print(f"❌ Error: {e}")
        import traceback
        traceback.print_exc()

    finally:
        client.close()

# Run the extraction only when executed as a script, not when imported.
if __name__ == "__main__":
    main()