#!/usr/bin/env python3
"""
Redis → ic3_asset_monitoring_data ingestion worker
Fetches live IoT tag readings from Redis and stores them in PostgreSQL.
"""

import os
import sys
import json
import time
import logging
import requests
import argparse
from datetime import datetime
from urllib.parse import quote
import psycopg2
from psycopg2.extras import execute_values

# ═════════════════════════════════════════════════════════════════════════════
# Configuration
# ═════════════════════════════════════════════════════════════════════════════

PGHOST = os.getenv("PGHOST")
# `or` instead of a getenv() default so that an exported-but-empty PGPORT
# also falls back to 5432 rather than crashing int("") at import time.
PGPORT = int(os.getenv("PGPORT") or "5432")
PGDATABASE = os.getenv("PGDATABASE")
PGUSER = os.getenv("PGUSER")
PGPASSWORD = os.getenv("PGPASSWORD")
# Base URL of the HTTP API in front of Redis (no default).
REDIS_API_BASE = os.getenv("REDIS_API_BASE")
# Log file path. The default is relative, so it resolves against the current
# working directory of the process; override with REDIS_INGESTION_LOG.
LOG_FILE = os.getenv("REDIS_INGESTION_LOG") or "redis_ingestion.log"

# Setup logging: records go both to LOG_FILE and to stderr
# (the default stream of logging.StreamHandler).
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(LOG_FILE),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)

# ═════════════════════════════════════════════════════════════════════════════
# Database functions
# ═════════════════════════════════════════════════════════════════════════════

def get_db_connection():
    """Open a new psycopg2 connection from the module-level PG* settings.

    Returns:
        A freshly opened connection; the caller is responsible for closing it.
    """
    connect_kwargs = {
        "host": PGHOST,
        "port": PGPORT,
        "database": PGDATABASE,
        "user": PGUSER,
        "password": PGPASSWORD,
    }
    return psycopg2.connect(**connect_kwargs)

def ensure_unique_index(conn):
    """Idempotently create idx_monitoring_tag_time_unique.

    The index covers ic3_asset_monitoring_data (tag_code, recorded_at), which
    is the column pair used as the ON CONFLICT target when inserting
    monitoring rows. On failure the error is logged and the transaction is
    rolled back; nothing is raised to the caller.
    """
    create_sql = """
                CREATE UNIQUE INDEX IF NOT EXISTS idx_monitoring_tag_time_unique
                ON ic3_asset_monitoring_data (tag_code, recorded_at);
            """
    try:
        with conn.cursor() as cur:
            cur.execute(create_sql)
        conn.commit()
    except Exception as exc:
        logger.error(f"Failed to create index: {exc}")
        conn.rollback()
    else:
        logger.info("Unique index ensured")

def get_asset_id(conn, serialnumber):
    """Look up asset_id in ic3_asset_master by serialnumber.

    Args:
        conn: open psycopg2 connection.
        serialnumber: serial number to match (the Redis payload's asset_id).

    Returns:
        The first matching asset_id, or None when there is no match or the
        query fails. A failure is logged and the transaction is rolled back.
    """
    try:
        with conn.cursor() as cur:
            cur.execute(
                "SELECT asset_id FROM ic3_asset_master WHERE serialnumber = %s LIMIT 1",
                (serialnumber,)
            )
            result = cur.fetchone()
            return result[0] if result else None
    except Exception as e:
        logger.error(f"Error querying asset_master: {e}")
        # After a failed statement psycopg2 leaves the transaction aborted and
        # rejects every later command on this connection until a rollback.
        # Without this, one bad lookup would silently fail all remaining
        # lookups and the final insert.
        conn.rollback()
        return None

def _escape_like(text):
    """Escape LIKE metacharacters in *text* using '!' as the escape character."""
    text = str(text)
    # Escape the escape character first so the later replacements stay literal.
    return text.replace("!", "!!").replace("%", "!%").replace("_", "!_")


def get_tag_mapping(conn, asset_id, param_tag_code):
    """
    Look up ic3_asset_tag_map for the given asset_id and param_tag_code.

    A mapping matches when its tag_code ends in '.<param_tag_code>'. The code
    is matched literally: '_' and '%' inside it are not treated as LIKE
    wildcards.

    Returns (tag_code, param_id, unit) or (None, None, None) if not found or
    the query fails (the failure is logged and the transaction rolled back).
    """
    try:
        # Without escaping, a tag such as 'TEMP_1' would also match
        # 'X.TEMPA1', because '_' is LIKE's single-character wildcard.
        pattern = f"%.{_escape_like(param_tag_code)}"
        with conn.cursor() as cur:
            # NOTE(review): LIMIT 1 without ORDER BY picks an arbitrary row if
            # several tag_codes end in the same suffix for one asset.
            cur.execute("""
                SELECT atm.tag_code, atm.param_id, pm.unit
                FROM ic3_asset_tag_map atm
                INNER JOIN ic3_parameter_master pm ON pm.param_id = atm.param_id
                WHERE atm.asset_id = %s AND atm.tag_code LIKE %s ESCAPE '!'
                LIMIT 1
            """, (asset_id, pattern))
            result = cur.fetchone()
            return result if result else (None, None, None)
    except Exception as e:
        logger.error(f"Error querying tag_map: {e}")
        # psycopg2 leaves the transaction aborted after a failed statement;
        # roll back so later queries on this connection keep working.
        conn.rollback()
        return (None, None, None)

def get_common_param_id(conn, param_tag_code):
    """Get param_id from ic3_parameter_master for common metadata parameters.

    Only rows with asset_group = 'COMMON' are considered.

    Returns:
        The param_id, or None when not found or the query fails (the failure
        is logged and the transaction rolled back).
    """
    try:
        with conn.cursor() as cur:
            cur.execute(
                "SELECT param_id FROM ic3_parameter_master WHERE param_tag_code = %s AND asset_group = 'COMMON' LIMIT 1",
                (param_tag_code,)
            )
            result = cur.fetchone()
            return result[0] if result else None
    except Exception as e:
        logger.error(f"Error querying common param: {e}")
        # psycopg2 leaves the transaction aborted after a failed statement;
        # roll back so later queries on this connection keep working.
        conn.rollback()
        return None

def insert_monitoring_rows(conn, rows, page_size=1000):
    """Batch insert rows into ic3_asset_monitoring_data.

    Rows whose (tag_code, recorded_at) already exist are skipped by
    ``ON CONFLICT (tag_code, recorded_at) DO NOTHING``.

    Args:
        conn: open psycopg2 connection; committed on success, rolled back on
            failure.
        rows: sequence of (asset_id, tag_code, param_id, value_text,
            value_num, unit, recorded_at) tuples.
        page_size: maximum number of rows sent per INSERT statement.

    Returns:
        int: number of rows actually inserted (conflict-skipped rows are not
        counted); 0 when *rows* is empty or the insert fails.

    Raises:
        ValueError: if page_size is smaller than 1.
    """
    if page_size < 1:
        raise ValueError("page_size must be >= 1")
    rows = list(rows)
    if not rows:
        return 0
    sql = """
                INSERT INTO ic3_asset_monitoring_data
                (asset_id, tag_code, param_id, value_text, value_num, unit, recorded_at)
                VALUES %s
                ON CONFLICT (tag_code, recorded_at) DO NOTHING
                """
    try:
        inserted = 0
        with conn.cursor() as cur:
            # Each execute_values call gets at most page_size rows, so it runs
            # a single INSERT and cur.rowcount describes exactly that chunk.
            # (With several pages per call, rowcount would only reflect the
            # last page.) For ON CONFLICT DO NOTHING, rowcount excludes the
            # skipped rows.
            for start in range(0, len(rows), page_size):
                chunk = rows[start:start + page_size]
                execute_values(cur, sql, chunk, page_size=page_size)
                inserted += max(cur.rowcount, 0)
        conn.commit()
        logger.info(
            f"Inserted {inserted} monitoring data rows "
            f"({len(rows) - inserted} skipped as duplicates)"
        )
        return inserted
    except Exception as e:
        logger.error(f"Error inserting monitoring data: {e}")
        conn.rollback()
        return 0

# ═════════════════════════════════════════════════════════════════════════════
# Redis API functions
# ═════════════════════════════════════════════════════════════════════════════

def get_redis_keys(pattern="ic3:latest:*"):
    """Return the Redis keys matching *pattern* via the HTTP Redis API.

    The pattern is URL-quoted into the ``/api/redis/keys`` query string and
    the ``keys`` list of the JSON reply is returned. Any failure (network,
    HTTP status, bad JSON) is logged and yields an empty list.
    """
    try:
        endpoint = "{}/api/redis/keys?pattern={}".format(REDIS_API_BASE, quote(pattern))
        reply = requests.get(endpoint, timeout=10)
        reply.raise_for_status()
        found = reply.json().get("keys", [])
        logger.info(f"Found {len(found)} Redis keys matching {pattern}")
    except Exception as exc:
        logger.error(f"Error scanning Redis keys: {exc}")
        return []
    return found

def get_redis_value(key):
    """Return the JSON-decoded value stored under Redis *key*, or None.

    The API reply is read as ``{"value": "<json text>"}``; a reply without a
    ``value`` field decodes to an empty dict. Network errors, HTTP errors and
    malformed JSON are logged and produce None.
    """
    try:
        endpoint = "{}/api/redis/key?k={}".format(REDIS_API_BASE, quote(key))
        reply = requests.get(endpoint, timeout=10)
        reply.raise_for_status()
        encoded = reply.json().get("value", "{}")
        decoded = json.loads(encoded)
    except Exception as exc:
        logger.error(f"Error fetching Redis key {key}: {exc}")
        return None
    return decoded

# ═════════════════════════════════════════════════════════════════════════════
# Main worker
# ═════════════════════════════════════════════════════════════════════════════

# Common metadata parameters written for every matched payload that carries an
# occurred_at timestamp. Each entry is:
#   (param_tag_code looked up in ic3_parameter_master with asset_group 'COMMON',
#    tag_code stored in ic3_asset_monitoring_data, unit, payload key, default)
# NOTE(review): these tag_codes are the same for every asset, while the insert's
# ON CONFLICT target is (tag_code, recorded_at). Two assets reporting the same
# occurred_at therefore collide, and only one asset's metadata row is kept.
# Confirm whether asset_id should be part of the unique key.
_COMMON_PARAMS = (
    ("GOOD_UNCERTAIN_BAD_BAD_COMM_FAILURE_BAD_OUT_OF_SERVICE", "COMMON.QUALITY", "enum", "quality", "UNKNOWN"),
    ("NORMAL_WARNING_ALARM_CRITICAL_STALE", "COMMON.ALARM", "enum", "alarm", "NORMAL"),
    ("UNIQUE_INGESTION_RECORD_ID", "COMMON.RECORD_ID", "string", "record_id", ""),
    ("FIELD_TIMESTAMP_FROM_DEVICE_SOURCE", "COMMON.TIMESTAMP", "timestamp", "occurred_at", None),
)


def _lookup_common_params(conn):
    """Resolve _COMMON_PARAMS to (param_id, tag_code, unit, key, default), dropping unresolved ones."""
    resolved = []
    for lookup_code, tag_code, unit, field, default in _COMMON_PARAMS:
        param_id = get_common_param_id(conn, lookup_code)
        if param_id:
            resolved.append((param_id, tag_code, unit, field, default))
    return resolved


def _log_summary(start_time, stats, unmatched_serials, unmapped_tags):
    """Log the end-of-pass summary plus up to 10 sample unmatched serials / unmapped tags."""
    elapsed = datetime.now() - start_time
    logger.info(f"\n[IC3 Worker] {start_time.strftime('%Y-%m-%d %H:%M:%S')}")
    logger.info(f"  Redis keys scanned : {stats['redis_keys_scanned']}")
    logger.info(f"  Assets matched     : {stats['assets_matched']}")
    logger.info(f"  Assets unmatched   : {stats['assets_unmatched']}  (see log for serial numbers)")
    logger.info(f"  Rows inserted      : {stats['rows_inserted']:,}")
    logger.info(f"  Tags skipped       : {stats['tags_skipped']}  (unmapped in ic3_asset_tag_map)")
    logger.info(f"  Elapsed time       : {elapsed.total_seconds():.2f}s")

    if unmatched_serials:
        # Serials come straight from the Redis payload and may be None or
        # non-strings; str() keeps join() from raising TypeError.
        logger.info(f"\nUnmatched serials: {', '.join(map(str, unmatched_serials[:10]))}")
    if unmapped_tags:
        logger.info(f"\nUnmapped tags: {', '.join(unmapped_tags[:10])}")


def process_redis_to_monitoring():
    """Run one ingestion pass from Redis into ic3_asset_monitoring_data.

    Steps for each key returned by get_redis_keys():
      1. The payload's ``asset_id`` is matched to ic3_asset_master.serialnumber.
      2. Each entry of ``values`` is mapped through ic3_asset_tag_map.
      3. When ``occurred_at`` is present, the COMMON metadata rows are added.

    All rows are inserted in a single batch at the end. The DB connection is
    closed on every exit path. Errors inside the pass are logged, not raised.

    Returns:
        dict: counters redis_keys_scanned, assets_matched, assets_unmatched,
        rows_inserted, tags_skipped.
    """
    start_time = datetime.now()

    stats = {
        "redis_keys_scanned": 0,
        "assets_matched": 0,
        "assets_unmatched": 0,
        "rows_inserted": 0,
        "tags_skipped": 0,
    }

    unmatched_serials = []
    unmapped_tags = []
    conn = None

    try:
        conn = get_db_connection()
        ensure_unique_index(conn)

        # Step 1: Get all Redis keys
        keys = get_redis_keys()
        stats["redis_keys_scanned"] = len(keys)

        if not keys:
            logger.warning("No Redis keys found")
            return stats  # the finally clause still closes conn

        # The COMMON param_ids do not depend on the payload: resolve them once
        # per pass instead of issuing four queries per Redis key.
        common_params = _lookup_common_params(conn)

        # Step 2-5: Process each key
        all_rows = []

        for key in keys:
            value = get_redis_value(key)
            if not value:
                logger.debug(f"Failed to parse value for key: {key}")
                continue
            if not isinstance(value, dict):
                # A non-object payload would raise on .get() and abort the
                # whole pass, discarding rows from every other key.
                logger.warning(f"Unexpected payload type for key {key}: {type(value).__name__}")
                continue

            # Extract fields from Redis payload
            redis_asset_id = value.get("asset_id")
            occurred_at = value.get("occurred_at")
            values_dict = value.get("values") or {}
            if not isinstance(values_dict, dict):
                logger.warning(f"Ignoring non-object 'values' for key: {key}")
                values_dict = {}

            # Step 3: Match to ic3_asset_master
            asset_id = get_asset_id(conn, redis_asset_id)
            if not asset_id:
                unmatched_serials.append(redis_asset_id)
                stats["assets_unmatched"] += 1
                logger.debug(f"Unmatched serial: {redis_asset_id}")
                continue

            stats["assets_matched"] += 1

            # Step 4: Match each parameter value to tag_map
            for param_tag_code, param_value in values_dict.items():
                tag_code, param_id, unit = get_tag_mapping(conn, asset_id, param_tag_code)
                if not tag_code:
                    unmapped_tags.append(f"{redis_asset_id}/{param_tag_code}")
                    stats["tags_skipped"] += 1
                    logger.debug(f"Unmapped tag: {redis_asset_id} / {param_tag_code}")
                    continue

                # Step 5: Prepare row for insertion.
                # NOTE(review): unlike the COMMON rows below, these are queued
                # even when occurred_at is missing, so recorded_at may be NULL.
                # If that column is NOT NULL, the whole batch insert fails.
                # Confirm the schema.
                all_rows.append((
                    asset_id,
                    tag_code,
                    param_id,
                    str(param_value),
                    float(param_value) if isinstance(param_value, (int, float)) else None,
                    unit,
                    occurred_at
                ))

            # Step 6: Common metadata rows (quality, alarm, record id, timestamp)
            if occurred_at:
                for param_id, tag_code, unit, field, default in common_params:
                    all_rows.append((
                        asset_id,
                        tag_code,
                        param_id,
                        value.get(field, default),
                        None,
                        unit,
                        occurred_at
                    ))

        # Batch insert all rows
        stats["rows_inserted"] = insert_monitoring_rows(conn, all_rows)

    except Exception as e:
        logger.error(f"Worker error: {e}", exc_info=True)
    finally:
        # Close on every exit path (early return, exception, success) so that
        # poll mode does not leak one connection per cycle.
        if conn is not None:
            try:
                conn.close()
            except Exception as close_err:
                logger.warning(f"Error closing DB connection: {close_err}")

    _log_summary(start_time, stats, unmatched_serials, unmapped_tags)
    return stats

# ═════════════════════════════════════════════════════════════════════════════
# CLI
# ═════════════════════════════════════════════════════════════════════════════

def main(argv=None):
    """Command-line entry point: run one ingestion pass, or poll forever.

    Args:
        argv: argument list to parse instead of ``sys.argv[1:]`` (useful for
            tests); None keeps the normal CLI behaviour.
    """
    parser = argparse.ArgumentParser(description="Redis → IC3 monitoring data ingestion worker")
    parser.add_argument("--mode", default="once", choices=["once", "poll"],
                       help="Run mode: 'once' or 'poll'")
    parser.add_argument("--interval", type=int, default=30,
                       help="Polling interval in seconds (for --mode poll)")
    args = parser.parse_args(argv)
    if args.interval < 0:
        # time.sleep() rejects negative values, which would crash the poll loop.
        parser.error("--interval must be a non-negative number of seconds")

    logger.info(f"IC3 Redis Worker starting — mode={args.mode}, interval={args.interval}s")
    logger.info(f"PostgreSQL: {PGUSER}@{PGHOST}:{PGPORT}/{PGDATABASE}")
    logger.info(f"Redis API: {REDIS_API_BASE}")

    if args.mode == "once":
        process_redis_to_monitoring()
        return

    # Otherwise args.mode == "poll" (argparse restricts --mode to the two choices).
    try:
        while True:
            try:
                process_redis_to_monitoring()
            except Exception:
                # One failed pass must not stop a long-running poller: log the
                # traceback and retry after the normal interval.
                logger.exception("Polling pass failed; will retry after the interval")
            logger.info(f"Sleeping for {args.interval}s...")
            time.sleep(args.interval)
    except KeyboardInterrupt:
        logger.info("Worker stopped by user")


if __name__ == "__main__":
    main()
