import os
import traceback

import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.cluster import KMeans
from pymongo import MongoClient
from sentence_transformers import SentenceTransformer


# Environment variable consulted for the MongoDB connection string when the
# caller does not pass one explicitly.
_MONGO_URI_ENV = "MONGODB_URI"

# Numeric fields that are coerced to numbers and min-max normalised into
# ``<name>_norm`` columns.
_AHP_COLS = ['price_actual', 'item_rating', 'total_sold']

# Upper bound on the number of clusters fitted inside one category.
_MAX_CLUSTERS = 3

# Display-only fields copied onto clustered records from ``products_clean``.
_DISPLAY_FIELDS = ('item_rating_display', 'price_ori_display')


def _main_category(detail):
    """Return the second '|'-separated segment of *detail*, stripped and lower-cased.

    Any value that is not a string containing '|' maps to 'unknown'.
    """
    if isinstance(detail, str) and '|' in detail:
        return detail.split('|')[1].strip().lower()
    return 'unknown'


def _cluster_by_category(df):
    """Fit KMeans separately on the title embeddings of every main category.

    Returns a ``(clustered_frames, skipped)`` pair: a list of per-category
    DataFrames carrying a new ``cluster`` column, and a list of
    ``{"category", "reason"}`` dicts for categories that were not clustered.
    """
    clustered_frames = []
    skipped = []

    # sort=False keeps categories in order of first appearance.
    for category, group in df.groupby('main_category_original', sort=False):
        try:
            print(f"\n\U0001F4C2 Processing category: {category} — {len(group)} items")

            # KMeans needs at least two samples to form two clusters.
            if len(group) < 2:
                skipped.append({"category": category, "reason": "too few items"})
                continue

            df_cat = group.copy()
            embeddings = np.vstack(df_cat['embedding'].values)
            X_scaled = StandardScaler().fit_transform(embeddings)

            # len(df_cat) >= 2 here, so k is always 2 or 3.
            k = min(_MAX_CLUSTERS, len(df_cat))
            kmeans = KMeans(n_clusters=k, random_state=42)
            df_cat['cluster'] = kmeans.fit_predict(X_scaled)

            clustered_frames.append(df_cat)
        except Exception as e:
            # Best effort: one failing category must not abort the others;
            # the failure is printed and recorded as a skipped category.
            print(f"\n❌ ERROR in category '{category}': {e}")
            traceback.print_exc()
            skipped.append({"category": category, "reason": str(e)})

    return clustered_frames, skipped


def _attach_display_fields(records, products_clean_col):
    """Copy the display fields from ``products_clean`` onto matching records.

    Records are matched on their ``id`` value and updated in place.
    Returns the number of records that were updated.
    """
    projection = {'id': 1}
    projection.update({field: 1 for field in _DISPLAY_FIELDS})
    metadata = {
        doc['id']: {field: doc.get(field) for field in _DISPLAY_FIELDS}
        for doc in products_clean_col.find({}, projection)
        if 'id' in doc  # tolerate source documents that lack an id
    }

    updates = 0
    for record in records:
        prod_id = record.get('id')
        if prod_id in metadata:
            record.update(metadata[prod_id])
            updates += 1
    return updates


def _run_pipeline(db):
    """Execute the load / normalise / embed / cluster / store steps on *db*."""
    collection = db['preprocessData']
    kmeans_collection = db['kmeans']
    skipped_coll = db['kmeans_skipped']

    # Load the whole dataset (no filtering by attire).
    df = pd.DataFrame(list(collection.find()))
    if df.empty:
        # Nothing to cluster: leave the existing output collections untouched.
        print("⚠️ 'preprocessData' is empty; nothing to cluster.")
        return

    df['main_category_original'] = df['item_category_detail'].apply(_main_category)

    # Normalize AHP fields: non-numeric values become 0 before scaling.
    for col in _AHP_COLS:
        df[col] = pd.to_numeric(df[col], errors='coerce').fillna(0)
    norm_cols = [f'{col}_norm' for col in _AHP_COLS]
    df[norm_cols] = MinMaxScaler().fit_transform(df[_AHP_COLS])

    # Generate embeddings in one batched call instead of one call per row.
    model = SentenceTransformer('all-MiniLM-L6-v2')
    titles = [str(title) for title in df['title']]
    df['embedding'] = list(model.encode(titles))
    print("\U0001F9E0 Embedding complete.")

    clustered_frames, skipped_categories = _cluster_by_category(df)

    # pd.concat raises ValueError on an empty list, so handle the case where
    # every category was skipped explicitly.
    if clustered_frames:
        df_clustered = pd.concat(clustered_frames, ignore_index=True)
    else:
        df_clustered = pd.DataFrame()
    print(f"\n🔄 Total clustered records: {len(df_clustered)}")

    # Make embeddings MongoDB safe (numpy arrays cannot be BSON-encoded).
    records = df_clustered.to_dict(orient='records')
    for record in records:
        embedding = record.get('embedding')
        if hasattr(embedding, "tolist"):
            record['embedding'] = embedding.tolist()

    # Merge display fields before inserting so the stored documents are
    # complete from the start and no per-document update pass is needed.
    print("🔄 Syncing item_rating_display and price_ori_display from products_clean...")
    updates = _attach_display_fields(records, db['products_clean'])

    # Overwrite in MongoDB
    kmeans_collection.delete_many({})
    print("🗑 Cleared 'kmeans' collection.")

    if records:
        kmeans_collection.insert_many(records)
        print(f"✅ Inserted {len(records)} clustered records.")
    else:
        print("⚠️ No records inserted.")

    # Save skipped info
    skipped_coll.delete_many({})
    if skipped_categories:
        skipped_coll.insert_many(skipped_categories)
        print(f"📄 Logged {len(skipped_categories)} skipped categories.")
    else:
        print("✅ No categories skipped.")

    print(f"✅ Synced {updates} documents in 'kmeans' with display fields.")


def run_kmeans_pipeline(mongo_uri=None):
    """Cluster products per category with KMeans and store the result in MongoDB.

    Reads every document from ``fyp.preprocessData``, derives a main category
    from ``item_category_detail``, min-max normalises ``price_actual``,
    ``item_rating`` and ``total_sold``, embeds ``title`` with the
    ``all-MiniLM-L6-v2`` sentence-transformer model and fits KMeans (at most
    3 clusters) inside each category. The ``kmeans`` collection is then
    replaced with the clustered records (enriched with the display fields
    from ``products_clean``) and ``kmeans_skipped`` with the categories that
    could not be clustered. If ``preprocessData`` is empty nothing is written.

    Args:
        mongo_uri: MongoDB connection string. When omitted, the value of the
            ``MONGODB_URI`` environment variable is used.

    Raises:
        RuntimeError: If no connection string is supplied by either route.

    NOTE: A connection string containing a username and password used to be
    hard-coded here. It has been removed from the code; the credential that
    was committed should be rotated.
    """
    print("\U0001F680 Running KMeans pipeline...")

    uri = mongo_uri or os.environ.get(_MONGO_URI_ENV)
    if not uri:
        raise RuntimeError(
            f"MongoDB connection string missing: pass mongo_uri or set the "
            f"{_MONGO_URI_ENV} environment variable."
        )

    # The context manager closes the client even if a step fails.
    with MongoClient(uri) as client:
        _run_pipeline(client['fyp'])


# Run the pipeline only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    run_kmeans_pipeline()
