"""Duomenų bazės sluoksnis (SQLAlchemy 2.x).

Apibrėžia modelius (countries, sources, trends, trend_history, trend_links)
bei pateikia patogų DatabaseHandler su pagrindinėmis operacijomis.
"""

from __future__ import annotations

import datetime as dt
from typing import Iterable, List, Optional, Sequence

from sqlalchemy import (
    BigInteger, Boolean, Column, DateTime, Enum, ForeignKey, Index,
    Integer, Numeric, String, Text, UniqueConstraint, create_engine, func, select,
)
from sqlalchemy.dialects.mysql import INTEGER as MyInt, BIGINT as MyBigInt
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship, Session, sessionmaker

from .logger import get_logger

# Module-level logger obtained from the project's logging helper (.logger).
log = get_logger(__name__)


class Base(DeclarativeBase):
    """Declarative base shared by every ORM model defined in this module."""


# ---------------------------------------------------------------------------
# Modeliai
# ---------------------------------------------------------------------------
class Country(Base):
    """A country row (table ``countries``); referenced by ``sources`` and ``trends``."""
    __tablename__ = "countries"

    # Unsigned INT surrogate key (MySQL dialect type).
    id            = mapped_column(MyInt(unsigned=True), primary_key=True, autoincrement=True)
    # Two-character code, unique; presumably ISO 3166-1 alpha-2 (lookups uppercase it).
    iso_code      = mapped_column(String(2), nullable=False, unique=True)
    # Optional three-character code, presumably ISO 3166-1 alpha-3.
    iso_code3     = mapped_column(String(3))
    # Display names (Lithuanian / English by their suffix).
    name_lt       = mapped_column(String(120), nullable=False)
    name_en       = mapped_column(String(120), nullable=False)
    # NOTE(review): presumably an IANA zone name such as "Europe/Vilnius" — confirm.
    timezone      = mapped_column(String(64))
    # Up to 5 chars; presumably a language tag like "lt" or "pt-BR" — confirm.
    primary_language = mapped_column(String(5))
    # NOTE(review): presumably the geo code passed to pytrends — confirm with the collector.
    pytrends_geo  = mapped_column(String(16))
    # Only enabled countries are returned by DatabaseHandler.get_enabled_countries.
    enabled       = mapped_column(Boolean, default=True, nullable=False)


class Source(Base):
    """A data source (table ``sources``), optionally tied to one country."""
    __tablename__ = "sources"

    # Unsigned INT surrogate key (MySQL dialect type).
    id           = mapped_column(MyInt(unsigned=True), primary_key=True, autoincrement=True)
    source_name  = mapped_column(String(120), nullable=False)
    source_type  = mapped_column(String(40), nullable=False)
    url          = mapped_column(String(500))
    # Nullable: NULL means the source is not bound to a country. Deleting the
    # referenced country sets this to NULL at the DB level (ondelete="SET NULL").
    country_id   = mapped_column(MyInt(unsigned=True), ForeignKey("countries.id", ondelete="SET NULL"))
    # Outcome of the latest fetch; written by DatabaseHandler.mark_source_fetch.
    last_fetch   = mapped_column(DateTime)
    last_status  = mapped_column(String(40))
    last_error   = mapped_column(Text)
    enabled      = mapped_column(Boolean, default=True, nullable=False)

    __table_args__ = (
        # NOTE(review): in MySQL, NULLs are distinct in UNIQUE indexes, so this
        # does not prevent duplicate names among rows with country_id IS NULL.
        UniqueConstraint("source_name", "country_id", name="uq_source"),
        Index("idx_source_type", "source_type"),
        Index("idx_country", "country_id"),
    )


class Trend(Base):
    """One trending keyword per (country, keyword, trend type) (table ``trends``)."""
    __tablename__ = "trends"

    # Unsigned BIGINT surrogate key (MySQL dialect type).
    id            = mapped_column(MyBigInt(unsigned=True), primary_key=True, autoincrement=True)
    # Deleting the country or source deletes its trends at the DB level (CASCADE).
    country_id    = mapped_column(MyInt(unsigned=True), ForeignKey("countries.id", ondelete="CASCADE"), nullable=False)
    source_id     = mapped_column(MyInt(unsigned=True), ForeignKey("sources.id", ondelete="CASCADE"), nullable=False)
    category      = mapped_column(String(40))
    # Lookup key (part of the unique constraint below); keyword_raw holds a
    # longer, presumably un-normalized form — confirm with the collectors.
    keyword       = mapped_column(String(255), nullable=False)
    keyword_raw   = mapped_column(String(500))
    language      = mapped_column(String(5))
    translation_en = mapped_column(String(500))
    # Numeric(10, 4): up to 6 integer digits and 4 decimal places.
    score         = mapped_column(Numeric(10, 4))
    volume        = mapped_column(BigInteger)
    # Numeric(5, 4) fits values up to ±9.9999; presumably a score in [-1, 1] — confirm.
    sentiment     = mapped_column(Numeric(5, 4))
    # Set by DatabaseHandler.upsert_trend to naive UTC timestamps.
    first_seen    = mapped_column(DateTime, nullable=False)
    last_updated  = mapped_column(DateTime, nullable=False)
    trend_type    = mapped_column(Enum("realtime", "hourly", "daily", name="trend_type"),
                                  nullable=False, default="daily")
    description   = mapped_column(Text)
    url           = mapped_column(String(700))

    __table_args__ = (
        # Natural key used by DatabaseHandler.upsert_trend.
        UniqueConstraint("country_id", "keyword", "trend_type", name="uq_country_keyword_type"),
        Index("idx_country_type_updated", "country_id", "trend_type", "last_updated"),
        Index("idx_keyword", "keyword"),
        Index("idx_category", "category"),
        Index("idx_last_updated", "last_updated"),
    )


class TrendHistory(Base):
    """Point-in-time snapshot of a trend (table ``trend_history``).

    DatabaseHandler.upsert_trend appends one row per call.
    """
    __tablename__ = "trend_history"

    # Unsigned BIGINT surrogate key (MySQL dialect type).
    id            = mapped_column(MyBigInt(unsigned=True), primary_key=True, autoincrement=True)
    # Snapshots are deleted with their trend at the DB level (CASCADE).
    trend_id      = mapped_column(MyBigInt(unsigned=True),
                                  ForeignKey("trends.id", ondelete="CASCADE"), nullable=False)
    timestamp     = mapped_column(DateTime, nullable=False)
    # Filled from the caller's "rank" value in upsert_trend; may be NULL.
    rank_position = mapped_column(Integer)
    volume        = mapped_column(BigInteger)
    sentiment     = mapped_column(Numeric(5, 4))
    description   = mapped_column(Text)

    __table_args__ = (
        # Supports per-trend, time-ordered history queries.
        Index("idx_trend_time", "trend_id", "timestamp"),
    )


class TrendLink(Base):
    """Similarity link between two trends (table ``trend_links``).

    DatabaseHandler.add_trend_link stores each pair once, ordered so that
    trend_id_a < trend_id_b.
    """
    __tablename__ = "trend_links"

    # Unsigned BIGINT surrogate key (MySQL dialect type).
    id          = mapped_column(MyBigInt(unsigned=True), primary_key=True, autoincrement=True)
    # Links are removed when either trend is deleted (CASCADE).
    trend_id_a  = mapped_column(MyBigInt(unsigned=True),
                                ForeignKey("trends.id", ondelete="CASCADE"), nullable=False)
    trend_id_b  = mapped_column(MyBigInt(unsigned=True),
                                ForeignKey("trends.id", ondelete="CASCADE"), nullable=False)
    # Numeric(5, 4): values up to ±9.9999 with 4 decimal places.
    similarity  = mapped_column(Numeric(5, 4), nullable=False)
    # Name of the similarity method; add_trend_link defaults it to "tfidf".
    method      = mapped_column(String(40))

    __table_args__ = (
        UniqueConstraint("trend_id_a", "trend_id_b", name="uq_pair"),
    )


# ---------------------------------------------------------------------------
# Handler
# ---------------------------------------------------------------------------
class DatabaseHandler:
    """High-level API for working with the trends database.

    Each public method opens its own short-lived ``Session`` that is closed on
    exit; write methods commit before returning. Sessions are created with
    ``expire_on_commit=False``, so ORM objects returned here keep their loaded
    attribute values after the session is closed.
    """

    # MySQL TEXT columns hold at most 65,535 *bytes*. Error text is capped in
    # UTF-8 bytes (keeping the previous 65000 safety margin) because a
    # character-based cap can exceed the byte limit for multi-byte text.
    _ERROR_MAX_BYTES = 65000

    def __init__(self, cfg: dict):
        """Create the engine and session factory from a connection config.

        Args:
            cfg: Connection settings. Required keys: ``user``, ``password``,
                ``host``, ``database``. Optional: ``port`` (default 3306),
                ``charset`` (default ``"utf8mb4"``), ``pool_size`` (default 5),
                ``pool_recycle`` in seconds (default 1800).
        """
        # Local import keeps this block self-contained. URL.create escapes the
        # credentials, so passwords containing '@', ':', '/' or '#' cannot
        # corrupt the connection string (plain string interpolation could).
        from sqlalchemy.engine import URL

        url = URL.create(
            drivername="mysql+pymysql",
            username=cfg["user"],
            password=cfg["password"],
            host=cfg["host"],
            port=int(cfg.get("port", 3306)),
            database=cfg["database"],
            query={"charset": cfg.get("charset", "utf8mb4")},
        )
        self.engine = create_engine(
            url,
            pool_size=cfg.get("pool_size", 5),
            pool_recycle=cfg.get("pool_recycle", 1800),
            # Test pooled connections before use so stale ones are replaced.
            pool_pre_ping=True,
            future=True,
        )
        self.SessionLocal = sessionmaker(self.engine, expire_on_commit=False, future=True)

    @staticmethod
    def _utcnow() -> dt.datetime:
        """Return the current UTC time as a naive datetime.

        The DateTime columns are declared without ``timezone=True`` and this
        module stores naive UTC values in them. This is the non-deprecated
        equivalent of ``datetime.utcnow()``.
        """
        return dt.datetime.now(dt.timezone.utc).replace(tzinfo=None)

    # ---- schema creation --------------------------------------------------
    def create_schema(self) -> None:
        """Create any missing tables for all models registered on ``Base``.

        Uses ``metadata.create_all``, which skips tables that already exist and
        does not alter them.
        """
        Base.metadata.create_all(self.engine)
        log.info("DB schema užtikrinta (create_all).")

    # ---- session helper ---------------------------------------------------
    def session(self) -> Session:
        """Return a new ``Session``; the caller must close it (e.g. ``with``)."""
        return self.SessionLocal()

    # ---- countries --------------------------------------------------------
    def upsert_country(self, *, iso_code: str, iso_code3: Optional[str],
                       name_lt: str, name_en: str, timezone: Optional[str],
                       primary_language: Optional[str],
                       pytrends_geo: Optional[str]) -> Country:
        """Insert a country or update the existing row with the same ``iso_code``.

        On update, the names are always overwritten. Optional fields keep their
        stored value when the new value is falsy (``None`` or empty).

        Returns:
            The persisted ``Country``, refreshed from the database.
        """
        # Normalize like get_country / get_enabled_countries, which query with
        # uppercased codes; otherwise a lowercase insert is unreachable there.
        iso_code = iso_code.upper()
        with self.session() as s:
            obj = s.scalar(select(Country).where(Country.iso_code == iso_code))
            if obj is None:
                obj = Country(
                    iso_code=iso_code, iso_code3=iso_code3,
                    name_lt=name_lt, name_en=name_en, timezone=timezone,
                    primary_language=primary_language, pytrends_geo=pytrends_geo,
                )
                s.add(obj)
            else:
                obj.iso_code3 = iso_code3 or obj.iso_code3
                obj.name_lt = name_lt
                obj.name_en = name_en
                obj.timezone = timezone or obj.timezone
                obj.primary_language = primary_language or obj.primary_language
                obj.pytrends_geo = pytrends_geo or obj.pytrends_geo
            s.commit()
            s.refresh(obj)
            return obj

    def get_country(self, iso_code: str) -> Optional[Country]:
        """Return the country with the given code (case-insensitive), or ``None``."""
        with self.session() as s:
            return s.scalar(select(Country).where(Country.iso_code == iso_code.upper()))

    def get_enabled_countries(self, iso_filter: Sequence[str] | None = None) -> List[Country]:
        """Return enabled countries, optionally restricted to some ISO codes.

        Args:
            iso_filter: Codes to keep, matched case-insensitively after
                stripping whitespace. ``None``, an empty sequence, or one that
                contains ``"ALL"`` (in any case) means no restriction.
        """
        # Normalize first so the "ALL" sentinel is recognized in any case;
        # previously "all" was uppercased into the IN list and matched nothing.
        codes = [c.strip().upper() for c in iso_filter or ()]
        with self.session() as s:
            q = select(Country).where(Country.enabled.is_(True))
            if codes and "ALL" not in codes:
                q = q.where(Country.iso_code.in_(codes))
            return list(s.scalars(q))

    # ---- sources ----------------------------------------------------------
    def upsert_source(self, *, source_name: str, source_type: str,
                      url: Optional[str], country_id: Optional[int] = None) -> Source:
        """Insert a source or update the row matching (source_name, country_id).

        A ``None`` country_id matches rows whose country_id IS NULL. On update,
        ``source_type`` is overwritten and ``url`` only when a truthy value is
        given.

        Returns:
            The persisted ``Source``, refreshed from the database.
        """
        with self.session() as s:
            q = select(Source).where(
                Source.source_name == source_name,
                # SQLAlchemy renders "== None" as "IS NULL".
                Source.country_id == country_id,
            )
            obj = s.scalar(q)
            if obj is None:
                obj = Source(source_name=source_name, source_type=source_type,
                             url=url, country_id=country_id)
                s.add(obj)
            else:
                obj.source_type = source_type
                obj.url = url or obj.url
            s.commit()
            s.refresh(obj)
            return obj

    @staticmethod
    def _truncate_utf8(text: str, max_bytes: int) -> str:
        """Cut *text* to at most *max_bytes* UTF-8 bytes without splitting a character."""
        raw = text.encode("utf-8", errors="replace")
        if len(raw) <= max_bytes:
            return text
        # errors="ignore" drops a multi-byte sequence cut at the boundary.
        return raw[:max_bytes].decode("utf-8", errors="ignore")

    def mark_source_fetch(self, source_id: int, status: str,
                          error: Optional[str] = None) -> None:
        """Record the outcome of the latest fetch for a source.

        Sets ``last_fetch`` to the current UTC time and ``last_status`` to
        *status*. Sets ``last_error`` to *error*, or to an empty string when no
        error is given, which clears a previous error. Unknown ``source_id``
        values are ignored.
        """
        with self.session() as s:
            src = s.get(Source, source_id)
            if src is None:
                return
            src.last_fetch = self._utcnow()
            src.last_status = status
            src.last_error = self._truncate_utf8(error or "", self._ERROR_MAX_BYTES)
            s.commit()

    # ---- trends -----------------------------------------------------------
    def upsert_trend(self, **kw) -> Trend:
        """Insert or update a trend keyed by (country_id, keyword, trend_type).

        Every call also appends a ``TrendHistory`` snapshot. The trend write and
        the history row are committed in ONE transaction, so a failure cannot
        leave an updated trend without its history entry.

        Keyword args:
            Required: ``country_id``, ``keyword``, ``trend_type``, and, for a
            new trend, ``source_id``. Optional: ``category``, ``keyword_raw``,
            ``language``, ``translation_en``, ``score``, ``volume``,
            ``sentiment``, ``description``, ``url``, and ``rank`` (stored only in
            the history row).

        Returns:
            The persisted ``Trend``, refreshed from the database.

        Raises:
            KeyError: if a required key is missing.
        """
        now = self._utcnow()
        with self.session() as s:
            q = select(Trend).where(
                Trend.country_id == kw["country_id"],
                Trend.keyword == kw["keyword"],
                Trend.trend_type == kw["trend_type"],
            )
            obj = s.scalar(q)
            if obj is None:
                obj = Trend(
                    country_id=kw["country_id"],
                    source_id=kw["source_id"],
                    category=kw.get("category"),
                    keyword=kw["keyword"],
                    keyword_raw=kw.get("keyword_raw"),
                    language=kw.get("language"),
                    translation_en=kw.get("translation_en"),
                    score=kw.get("score"),
                    volume=kw.get("volume"),
                    sentiment=kw.get("sentiment"),
                    first_seen=now,
                    last_updated=now,
                    trend_type=kw["trend_type"],
                    description=kw.get("description"),
                    url=kw.get("url"),
                )
                s.add(obj)
            else:
                # Update policy: keep the first source_id so it stays stable.
                obj.last_updated = now
                # Only overwrite fields the caller actually supplied (non-None).
                for fld in ("category", "score", "volume", "description", "url",
                            "translation_en", "language", "sentiment", "keyword_raw"):
                    val = kw.get(fld)
                    if val is not None:
                        setattr(obj, fld, val)

            # Flush (not commit) so a new trend receives its autoincrement id
            # inside the same transaction as the history row below.
            s.flush()
            s.add(TrendHistory(
                trend_id=obj.id,
                timestamp=now,
                rank_position=kw.get("rank"),
                volume=kw.get("volume"),
                sentiment=kw.get("sentiment"),
                description=kw.get("description"),
            ))
            s.commit()
            # Reload the row so the returned object carries the stored values.
            s.refresh(obj)
            return obj

    def recent_trends(self, country_id: int, trend_type: str,
                      hours: int = 24, limit: int = 50) -> List[Trend]:
        """Return up to *limit* trends of one country and type, newest first.

        Only trends whose ``last_updated`` falls within the last *hours* hours
        (UTC) are included.
        """
        since = self._utcnow() - dt.timedelta(hours=hours)
        with self.session() as s:
            q = (select(Trend)
                 .where(Trend.country_id == country_id,
                        Trend.trend_type == trend_type,
                        Trend.last_updated >= since)
                 .order_by(Trend.last_updated.desc())
                 .limit(limit))
            return list(s.scalars(q))

    def all_recent_trends(self, hours: int = 24, limit: int = 1000) -> List[Trend]:
        """Return up to *limit* trends of any country/type updated in the last *hours* hours, newest first."""
        since = self._utcnow() - dt.timedelta(hours=hours)
        with self.session() as s:
            q = (select(Trend)
                 .where(Trend.last_updated >= since)
                 .order_by(Trend.last_updated.desc())
                 .limit(limit))
            return list(s.scalars(q))

    def add_trend_link(self, a_id: int, b_id: int, similarity: float,
                       method: str = "tfidf") -> None:
        """Store an undirected similarity link between two trends.

        The pair is normalized so that ``trend_id_a < trend_id_b``. Self-links
        are ignored. If the pair already exists, the stored row is left
        unchanged; its similarity is not updated.
        """
        if a_id == b_id:
            return
        if a_id > b_id:
            a_id, b_id = b_id, a_id
        with self.session() as s:
            existing = s.scalar(select(TrendLink).where(
                TrendLink.trend_id_a == a_id, TrendLink.trend_id_b == b_id))
            if existing is None:
                s.add(TrendLink(trend_id_a=a_id, trend_id_b=b_id,
                                similarity=similarity, method=method))
                s.commit()
