import os
import re
from contextlib import contextmanager
from typing import List, Optional, Dict

import pymysql
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

# LangChain + Anthropic
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import HumanMessage, SystemMessage

app = FastAPI(title="PsyMed AI API", version="0.4.0")

# ============================================
# MySQL connection settings
# ============================================
# Settings are read from the environment so credentials never live in source
# control. SECURITY: the password that was previously committed here must be
# treated as leaked and rotated.
# If the connection fails (e.g. no password configured), build_rules_from_db()
# prints the error and returns {}, and get_drug_rules() falls back to
# FALLBACK_RULES, so the API still starts.
DB_CONFIG = {
    "host": os.environ.get("PSYMED_DB_HOST", "localhost"),
    "user": os.environ.get("PSYMED_DB_USER", "root"),
    "password": os.environ.get("PSYMED_DB_PASSWORD", ""),
    "database": os.environ.get("PSYMED_DB_NAME", "psychological_hub"),
    "charset": "utf8mb4",
    # Rows come back as dicts keyed by column name; the loaders below rely on this.
    "cursorclass": pymysql.cursors.DictCursor,
}

@contextmanager
def get_db_connection():
    """Open a PyMySQL connection from ``DB_CONFIG`` and guarantee it is closed.

    Yields:
        The live connection. It is closed when the ``with`` block exits,
        whether that exit is normal or caused by an exception.
    """
    connection = pymysql.connect(**DB_CONFIG)
    try:
        yield connection
    finally:
        # Runs on normal exit and when the caller's with-body raises.
        connection.close()

def load_drugs_from_db() -> Dict[int, str]:
    """Load the drug catalogue as ``{drug_id: display_name}``.

    The display name is ``drug_name_en`` (English name) when it is non-empty,
    otherwise ``drug_name`` (Korean name).
    """
    with get_db_connection() as conn, conn.cursor() as cursor:
        cursor.execute("SELECT drug_id, drug_name, drug_name_en FROM pm_drugs")
        # English name first; fall back to the Korean name when it is empty/NULL.
        return {
            row['drug_id']: row['drug_name_en'] or row['drug_name']
            for row in cursor.fetchall()
        }

def load_contraindications_from_db() -> List[Dict]:
    """Fetch every contraindication row joined with its drug's names.

    Returns:
        All rows from ``fetchall()``. Each row carries the columns ``drug_id``,
        ``drug_name_en``, ``drug_name``, ``contraindication_type``,
        ``contraindication`` and ``reason``.
    """
    with get_db_connection() as conn, conn.cursor() as cursor:
        cursor.execute("""
                SELECT c.drug_id, d.drug_name_en, d.drug_name,
                       c.contraindication_type, c.contraindication, c.reason
                FROM pm_drug_contraindications c
                JOIN pm_drugs d ON c.drug_id = d.drug_id
            """)
        rows = cursor.fetchall()
    return rows

def build_rules_from_db() -> Dict:
    """Build the contraindication rule table from the DB.

    Returns:
        ``{drug_name: {condition_key: {"level": ..., "reason": ...}}}``.

        - ``drug_name`` is ``drug_name_en`` when non-empty, else ``drug_name``.
        - ``condition_key`` comes from ``map_condition``; rows it cannot map
          are skipped.
        - ``level`` is ``"banned"`` when ``contraindication_type == '절대'``
          (absolute contraindication), otherwise ``"caution"``.
        - ``reason`` is the row's ``reason``, falling back to its
          ``contraindication`` text.

        When several rows map to the same (drug, condition), a ``"banned"``
        rule is never downgraded to ``"caution"`` by a later row. Rows of
        equal severity keep last-wins behavior.

        On any error the exception is printed and ``{}`` is returned, so the
        caller can fall back to built-in rules.
    """
    rules = {}

    try:
        contras = load_contraindications_from_db()

        for row in contras:
            # English name preferred, Korean name otherwise.
            drug_name = row['drug_name_en'] or row['drug_name']
            if not drug_name:
                continue

            # Created before mapping, so drugs whose rows all fail to map still
            # get an (empty) entry.
            drug_rules = rules.setdefault(drug_name, {})

            # Map the DB condition label to a system condition key.
            condition = map_condition(row['contraindication'])
            if not condition:
                continue

            level = "banned" if row['contraindication_type'] == '절대' else "caution"

            # Safety: an absolute contraindication must not be overwritten by a
            # relative one that happens to come later in the result set.
            existing = drug_rules.get(condition)
            if existing and existing["level"] == "banned" and level != "banned":
                continue

            drug_rules[condition] = {
                "level": level,
                "reason": row['reason'] or row['contraindication'],
            }

        return rules
    except Exception as e:
        print(f"DB 로드 오류: {e}")
        return {}

# Substrings that disqualify a DB contraindication label from mapping to a
# patient condition: hypersensitivity/allergy, drug-combination/administration
# entries, and conditions the system does not model.
_CONDITION_EXCLUDE_PATTERNS = (
    '과민반응', '과민증', '알레르기',
    '병용', '투여', '복용',
    'SSRI', 'SNRI', 'MAOI', 'MAO',
    '정신병', '치매', '섬망',
    '무호흡', '근무력',
    '남용', '의존',
    '혼돈', '흥분',
    '갈색세포종',
)

# "중증 간…" (severe liver …), but not "중증 간질" (epilepsy) or "중증 간질성"
# (interstitial). "중증 간질환" (severe liver disease) is still accepted.
_SEVERE_LIVER_RE = re.compile(r'중증 간(?!질(?!환))')
# "중증 신…" (severe renal …), but not "중증 신경…" (severe neurological …).
_SEVERE_RENAL_RE = re.compile(r'중증 신(?!경)')


def map_condition(condition_name: str) -> Optional[str]:
    """Map a DB contraindication label to a system condition key (strict).

    Only clearly identified comorbidities are mapped. Pregnancy, lactation,
    diabetes, hypertension and seizure labels must match exactly (after
    stripping surrounding whitespace). Liver, renal and respiratory labels map
    only when they are severe.

    Args:
        condition_name: Contraindication text from ``pm_drug_contraindications``.

    Returns:
        One of ``'pregnancy'``, ``'lactation'``, ``'liver_severe'``,
        ``'renal_severe'``, ``'cardiac'``, ``'diabetes'``, ``'hypertension'``,
        ``'glaucoma'``, ``'seizure_history'`` or ``'respiratory'``. Returns
        ``None`` for empty input, excluded labels, or labels with no mapping.
    """
    if not condition_name:
        return None
    name = condition_name.strip()

    if any(pattern in name for pattern in _CONDITION_EXCLUDE_PATTERNS):
        return None

    # Pregnancy / lactation
    if name in ('임신', '임신 중', '임부'):
        return 'pregnancy'
    if name in ('수유', '수유 중', '수유부'):
        return 'lactation'

    # Liver function (severe only)
    if _SEVERE_LIVER_RE.search(name) or '간부전' in name:
        return 'liver_severe'

    # Renal function (severe only)
    if _SEVERE_RENAL_RE.search(name) or '신부전' in name:
        return 'renal_severe'

    # Cardiac / cardiovascular
    if any(x in name for x in ('심근경색', '심부전', '부정맥', '전도장애', 'QT')):
        return 'cardiac'

    # Diabetes
    if name in ('당뇨', '당뇨병'):
        return 'diabetes'

    # Hypertension
    if name in ('고혈압', '조절되지 않는 고혈압'):
        return 'hypertension'

    # Glaucoma (angle-closure only)
    if '폐쇄각' in name and '녹내장' in name:
        return 'glaucoma'

    # Seizure / convulsion / epilepsy
    if name in ('발작', '경련', '간질', '뇌전증'):
        return 'seizure_history'

    # Respiratory (severe only)
    if '중증 호흡' in name or '호흡부전' in name:
        return 'respiratory'

    return None

# ============================================
# LLM settings
# ============================================
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", "")

# The LLM client is optional. Without an API key `llm` stays None; the AI
# endpoints then report the feature as unavailable instead of failing.
llm = None
if ANTHROPIC_API_KEY:
    llm = ChatAnthropic(
        model="claude-sonnet-4-20250514",
        api_key=ANTHROPIC_API_KEY,
        max_tokens=1024,
    )

# Allowed CORS origins: comma-separated PSYMED_CORS_ORIGINS. When the variable
# is unset this defaults to "*" (any origin), the previous fixed behavior.
# Deployments can restrict it, e.g. "https://app.example.com".
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("PSYMED_CORS_ORIGINS", "*").split(",")
    if origin.strip()
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ============================================
# L1: RULES DATABASE (loaded from DB + fallback)
# ============================================

# Built-in rules. They are used on their own when the DB yields no rules, and
# otherwise serve as the base that DB rules are merged over in get_drug_rules()
# (DB entries win per drug/condition).
# Shape: {drug_name: {condition_key: {"level": str, "reason": str}}}
# Levels used here: "banned" and "caution" are reported as warnings;
# "preferred" is reported in a separate list by /api/check-safety.
FALLBACK_RULES = {
    "Valproate": {
        "pregnancy": {"level": "banned", "reason": "신경관결손 5-9%, 자폐위험"},
        "liver_severe": {"level": "banned", "reason": "간독성 위험"}
    },
    "Paroxetine": {
        "pregnancy": {"level": "banned", "reason": "1삼분기 심장기형 위험"},
        "diabetes": {"level": "caution", "reason": "체중증가로 혈당조절 악화 가능"}
    },
    "Lithium": {
        "pregnancy": {"level": "banned", "reason": "Ebstein 기형 위험"},
        "renal_severe": {"level": "banned", "reason": "중증 신기능저하시 금기"}
    },
    "Sertraline": {
        "pregnancy": {"level": "preferred", "reason": "임신중 안전성 데이터 가장 풍부"},
        "cardiac": {"level": "preferred", "reason": "심근경색 후 안전 (SADHART 연구)"}
    },
    # NOTE(review): "TCA" is a drug class, not a single drug name. It is reported
    # under the class name and will not merge with DB entries keyed by individual
    # drug names — confirm intended.
    "TCA": {
        "cardiac": {"level": "banned", "reason": "부정맥 위험"},
        "glaucoma": {"level": "banned", "reason": "녹내장 악화"}
    }
}

# Load rules from the DB (once, at import time)
def get_drug_rules() -> Dict:
    """Return the active rule table: DB rules merged over ``FALLBACK_RULES``.

    DB entries take precedence per (drug, condition). If the DB yields no rules
    or raises, a copy of ``FALLBACK_RULES`` is returned.

    The per-drug condition maps in the result are always fresh dicts. Merging
    therefore never mutates ``FALLBACK_RULES``, and the returned table is never
    the ``FALLBACK_RULES`` object itself.
    """
    def _fallback_copy() -> Dict:
        # One level deep is enough: merging replaces rule dicts, never edits them.
        return {drug: dict(conditions) for drug, conditions in FALLBACK_RULES.items()}

    try:
        db_rules = build_rules_from_db()
        if db_rules:
            # Merge DB rules over the fallback rules (DB wins).
            merged = _fallback_copy()
            for drug, conditions in db_rules.items():
                merged.setdefault(drug, {}).update(conditions)
            print(f"DB에서 {len(db_rules)}개 약물 규칙 로드됨")
            return merged
    except Exception as e:
        print(f"DB 로드 실패, fallback 사용: {e}")
    return _fallback_copy()

# Global rule table (loaded when the server starts)
DRUG_RULES = get_drug_rules()

# ============================================
# MODELS
# ============================================

class PatientInfo(BaseModel):
    """Patient profile sent by clients for rule checks and AI explanations.

    get_patient_conditions() turns the flags and grades below into condition keys.
    """
    # Free-text diagnosis; in this file it is only included in the LLM prompt.
    diagnosis: str = ""
    # Patients aged 65 or older get the "elderly" condition.
    age: int = 30
    pregnancy: bool = False
    # NOTE(review): not consulted by the rule checks in this file.
    trimester: Optional[int] = None
    lactation: bool = False
    # Expected "none" | "mild" | "moderate" | "severe"; any other value adds no condition.
    liver_disease: str = "none"
    # Expected "normal" | "mild" | "moderate" | "severe"; any other value adds no condition.
    renal_function: str = "normal"
    diabetes: bool = False
    cardiac: bool = False
    hypertension: bool = False
    hypothyroidism: bool = False
    seizure_history: bool = False
    glaucoma: bool = False
    eating_disorder: bool = False
    # Pydantic copies mutable field defaults per instance, so this [] is not shared.
    # NOTE(review): not consulted by the rule checks in this file.
    current_medications: List[str] = []

class ExplainRequest(BaseModel):
    """Request body for POST /api/explain."""
    patient: PatientInfo
    # Empty string -> the endpoint substitutes a default "explain comprehensively" question.
    question: str = ""
    # Optional precomputed rule result (banned/avoid/caution/preferred lists);
    # when missing or empty the endpoint evaluates the rules itself.
    rules_result: Optional[dict] = None

# ============================================
# HELPER FUNCTIONS
# ============================================

def get_patient_conditions(patient: PatientInfo) -> List[str]:
    """Translate a patient profile into the condition keys used by the rules.

    Graded fields yield ``liver_<grade>`` / ``renal_<grade>`` only for the
    grades "mild", "moderate" and "severe". Age 65+ yields ``"elderly"``.
    Boolean flags yield their own names. The output order is fixed.
    """
    graded = ("mild", "moderate", "severe")
    candidates = [
        ("pregnancy", patient.pregnancy),
        ("lactation", patient.lactation),
        (f"liver_{patient.liver_disease}", patient.liver_disease in graded),
        (f"renal_{patient.renal_function}", patient.renal_function in graded),
        ("diabetes", patient.diabetes),
        ("cardiac", patient.cardiac),
        ("hypertension", patient.hypertension),
        ("elderly", patient.age >= 65),
        ("hypothyroidism", patient.hypothyroidism),
        ("seizure_history", patient.seizure_history),
        ("glaucoma", patient.glaucoma),
        ("eating_disorder", patient.eating_disorder),
    ]
    return [key for key, present in candidates if present]

def check_drug_for_conditions(drug: str, conditions: List[str]) -> List[dict]:
    """List the rule hits for ``drug`` among the patient's ``conditions``.

    Each hit is ``{"drug", "condition", "level", "reason", "reason_ko"}``.
    ``reason_ko`` is read from an optional ``"<condition>_ko"`` entry in the
    drug's rules and otherwise repeats ``reason``. Unknown drugs yield ``[]``.
    """
    if drug not in DRUG_RULES:
        return []
    table = DRUG_RULES[drug]

    def describe(cond):
        base = table[cond]
        localized = table.get(f"{cond}_ko", base)
        return {
            "drug": drug,
            "condition": cond,
            "level": base["level"],
            "reason": base["reason"],
            "reason_ko": localized["reason"],
        }

    return [describe(cond) for cond in conditions if cond in table]

# ============================================
# API ENDPOINTS
# ============================================

@app.get("/health")
def health():
    """Liveness probe: status, API version and number of entries in DRUG_RULES."""
    # Use the version declared on the FastAPI app rather than a separate literal,
    # which had drifted ("0.2.0" here vs "0.4.0" on the app).
    return {"status": "ok", "version": app.version, "drugs_count": len(DRUG_RULES)}

@app.get("/api/drugs")
def list_drugs():
    """Return every drug name in the loaded rule table, with the total count."""
    names = [*DRUG_RULES]
    return {"drugs": names, "count": len(names)}

def group_by_drug(warnings: List[dict]) -> List[dict]:
    """Group rule findings per drug, preserving first-seen drug order.

    Each group is ``{"drug", "max_level", "reasons"}``.

    - ``reasons`` lists ``{"condition", "level", "reason"}`` in input order.
    - ``max_level`` is the most severe level on the scale
      banned > avoid > caution. Levels outside that scale (e.g. "preferred")
      rank lowest, so a group that starts with one keeps it unless a ranked
      level follows.
    """
    severity = {"banned": 3, "avoid": 2, "caution": 1}
    groups = {}

    for finding in warnings:
        drug = finding["drug"]
        if drug not in groups:
            groups[drug] = {"drug": drug, "max_level": finding["level"], "reasons": []}
        group = groups[drug]

        group["reasons"].append({
            "condition": finding["condition"],
            "level": finding["level"],
            "reason": finding["reason"],
        })

        # Raise the group's headline level when this finding is more severe.
        if severity.get(finding["level"], 0) > severity.get(group["max_level"], 0):
            group["max_level"] = finding["level"]

    return list(groups.values())

@app.post("/api/check-safety")
def check_safety(patient: PatientInfo):
    """Evaluate every drug in DRUG_RULES against the patient's conditions.

    The response contains:

    - Flat per-finding lists (``banned``/``avoid``/``caution``/``preferred``/
      ``safe``), kept for backward compatibility.
    - The same findings grouped per drug under ``grouped``. Warning groups are
      bucketed by their most severe level.
    - Finding and drug counts under ``summary``.
    """
    conditions = get_patient_conditions(patient)
    warning_levels = ("banned", "avoid", "caution")

    findings = [
        hit
        for drug in DRUG_RULES
        for hit in check_drug_for_conditions(drug, conditions)
    ]

    def at_level(items, level, key="level"):
        return [item for item in items if item[key] == level]

    warnings = [hit for hit in findings if hit["level"] in warning_levels]
    preferred = at_level(findings, "preferred")
    safe_list = at_level(findings, "safe")

    banned = at_level(warnings, "banned")
    avoid = at_level(warnings, "avoid")
    caution = at_level(warnings, "caution")

    # Per-drug grouping, bucketed by each drug's worst warning level.
    # NOTE(review): a drug can be "preferred" for one condition and carry a
    # warning for another; clients should not read `preferred` as "safe overall".
    drug_groups = group_by_drug(warnings)
    banned_drugs = at_level(drug_groups, "banned", key="max_level")
    avoid_drugs = at_level(drug_groups, "avoid", key="max_level")
    caution_drugs = at_level(drug_groups, "caution", key="max_level")

    return {
        "conditions": conditions,
        # Legacy flat format (backward compatible)
        "banned": banned,
        "avoid": avoid,
        "caution": caution,
        "preferred": preferred,
        "safe": safe_list,
        # Grouped-by-drug format
        "grouped": {
            "banned": banned_drugs,
            "avoid": avoid_drugs,
            "caution": caution_drugs,
            "preferred": group_by_drug(preferred),
        },
        "summary": {
            "banned_count": len(banned),
            "avoid_count": len(avoid),
            "caution_count": len(caution),
            "preferred_count": len(preferred),
            "safe_count": len(safe_list),
            # Counts per drug rather than per finding
            "banned_drugs": len(banned_drugs),
            "avoid_drugs": len(avoid_drugs),
            "caution_drugs": len(caution_drugs),
        },
    }

@app.post("/api/check-drug")
def check_single_drug(drug_name: str, patient: PatientInfo):
    """Check a single drug against the patient's conditions.

    ``drug_name`` is matched against the loaded rule table exactly first, then
    case-insensitively (ignoring surrounding whitespace). The response echoes
    the caller's spelling in ``drug``.

    ``is_safe`` is True when no ``banned``/``avoid`` rule applies. That is also
    the case when no rules exist for the name at all, so ``known_drug`` reports
    whether the name was found in the rule table.
    """
    conditions = get_patient_conditions(patient)

    # Without the fallback match, e.g. "valproate" for a pregnant patient would
    # silently report is_safe=True because only "Valproate" has rules.
    resolved = drug_name if drug_name in DRUG_RULES else None
    if resolved is None:
        wanted = drug_name.strip().casefold()
        resolved = next(
            (name for name in DRUG_RULES if str(name).casefold() == wanted),
            None,
        )

    results = (
        check_drug_for_conditions(resolved, conditions) if resolved is not None else []
    )
    is_safe = not any(r["level"] in ("banned", "avoid") for r in results)

    return {
        "drug": drug_name,
        "conditions": conditions,
        "warnings": results,
        "is_safe": is_safe,
        "known_drug": resolved is not None,
    }

@app.get("/api/conditions")
def list_conditions():
    """Describe the patient-condition inputs for building the client form.

    Each entry has ``key`` (the PatientInfo field), a Korean ``label`` and an
    input ``type`` ("bool", "select" with ``options``, or "number").
    """
    def flag(key, label):
        return {"key": key, "label": label, "type": "bool"}

    def choice(key, label, options):
        return {"key": key, "label": label, "type": "select", "options": options}

    entries = [
        flag("pregnancy", "임신"),
        flag("lactation", "수유"),
        choice("liver_disease", "간질환", ["none", "mild", "moderate", "severe"]),
        choice("renal_function", "신기능", ["normal", "mild", "moderate", "severe"]),
        flag("diabetes", "당뇨"),
        flag("cardiac", "심장질환"),
        flag("hypertension", "고혈압"),
        flag("hypothyroidism", "갑상선저하증"),
        flag("seizure_history", "발작력"),
        flag("glaucoma", "녹내장"),
        flag("eating_disorder", "섭식장애"),
        {"key": "age", "label": "나이", "type": "number"},
    ]
    return {"conditions": entries}

# ============================================
# L2/L3: AI explanation endpoint
# ============================================

# System prompt for the explanation LLM (Korean). It instructs the model to:
#   - treat the rule engine's banned/caution verdicts as fixed facts (L1);
#   - give priorities and alternatives in complex cases (L2);
#   - explain with rationale for the physician (L3);
#   - answer in Korean, keep drug names in English, and stay brief
#     (~200 characters recommended).
SYSTEM_PROMPT = """당신은 정신건강의학과 전문의를 보조하는 약물 안전성 전문가입니다.
환자 정보와 규칙 기반 분석 결과를 바탕으로 임상적 설명을 제공합니다.

원칙:
1. 규칙 엔진의 금기/주의 판정은 절대적으로 존중합니다 (L1 = 사실)
2. 복합 상황에서는 우선순위와 대안을 제시합니다 (L2 = 판단)
3. 의사가 이해하기 쉽게 근거와 함께 설명합니다 (L3 = 설명)
4. 한국어로 답변하고, 약물명은 영문 그대로 사용합니다
5. 간결하게 핵심만 전달합니다 (200자 이내 권장)
"""

# (rules_result key, Korean label) in the order sections appear in the prompt.
_RULE_SUMMARY_SECTIONS = (
    ("banned", "금기"),
    ("avoid", "회피"),
    ("caution", "주의"),
    ("preferred", "권장"),
)


def _summarize_patient(patient: PatientInfo) -> str:
    """Render the patient profile as a one-line Korean summary for the LLM prompt."""
    parts = [f"환자: {patient.age}세"]
    if patient.pregnancy:
        parts.append(f"임신({patient.trimester}삼분기)" if patient.trimester else "임신")
    if patient.lactation:
        parts.append("수유중")
    if patient.liver_disease != "none":
        parts.append(f"간기능 {patient.liver_disease}")
    if patient.renal_function != "normal":
        parts.append(f"신기능 {patient.renal_function}")
    # Include every boolean comorbidity so the LLM sees the same conditions the
    # rule engine evaluated (labels match /api/conditions).
    for present, label in (
        (patient.diabetes, "당뇨"),
        (patient.cardiac, "심장질환"),
        (patient.hypertension, "고혈압"),
        (patient.hypothyroidism, "갑상선저하증"),
        (patient.seizure_history, "발작력"),
        (patient.glaucoma, "녹내장"),
        (patient.eating_disorder, "섭식장애"),
    ):
        if present:
            parts.append(label)
    if patient.current_medications:
        parts.append("복용약: " + ", ".join(patient.current_medications))
    if patient.diagnosis:
        parts.append(f"진단: {patient.diagnosis}")
    return ", ".join(parts)


def _summarize_rules(rules_result: dict) -> str:
    """Render rule findings as Korean bullet lines, tolerating malformed client input."""
    summary = "규칙 분석 결과:\n"
    for key, label in _RULE_SUMMARY_SECTIONS:
        entries = rules_result.get(key)
        if not entries or not isinstance(entries, list):
            continue
        # Client-supplied entries may lack "reason_ko" (or be non-dicts);
        # fall back to "reason" instead of raising KeyError.
        items = [
            f"{r.get('drug', '?')}({r.get('reason_ko') or r.get('reason') or ''})"
            for r in entries
            if isinstance(r, dict)
        ]
        if items:
            summary += f"- {label}: " + ", ".join(items) + "\n"
    return summary


@app.post("/api/explain")
async def explain_with_ai(request: ExplainRequest):
    """Ask the LLM for a clinical explanation grounded in the rule-engine result.

    Uses ``request.rules_result`` when it is provided and non-empty; otherwise
    runs ``check_safety`` for the patient.

    Returns:
        On success: ``success``, ``explanation`` (the model's reply), and the
        ``patient_summary`` and ``rules_summary`` that were sent to the model.
        When the LLM is not configured or the call fails: ``success=False``
        with an ``error`` string and ``explanation=None``.
    """
    if not llm:
        return {
            "success": False,
            "error": "API key not configured",
            "explanation": None
        }

    # Reuse the /api/check-safety evaluation when the client did not send one.
    rules_result = request.rules_result or check_safety(request.patient)

    patient_summary = _summarize_patient(request.patient)
    rules_summary = _summarize_rules(rules_result)

    question = request.question or "이 환자에게 적합한 약물 선택에 대해 종합적으로 설명해주세요."

    user_message = f"""{patient_summary}

{rules_summary}

질문: {question}"""

    messages = [
        SystemMessage(content=SYSTEM_PROMPT),
        HumanMessage(content=user_message)
    ]

    try:
        # This handler is async: use the async API so the network round-trip
        # does not block the event loop (and every other request) meanwhile.
        response = await llm.ainvoke(messages)
    except Exception as e:
        return {
            "success": False,
            "error": str(e),
            "explanation": None
        }

    return {
        "success": True,
        "explanation": response.content,
        "patient_summary": patient_summary,
        "rules_summary": rules_summary
    }

@app.get("/api/ai-status")
def ai_status():
    """Report whether the LLM client is configured and, if so, which model it uses."""
    if llm:
        model_name = "claude-sonnet-4-20250514"
    else:
        model_name = None
    return {"llm_available": llm is not None, "model": model_name}

if __name__ == "__main__":
    import uvicorn

    # Defaults reproduce the previous fixed binding (all interfaces, port 8000).
    # PSYMED_HOST / PSYMED_PORT allow e.g. binding to 127.0.0.1 for local-only use.
    uvicorn.run(
        app,
        host=os.environ.get("PSYMED_HOST", "0.0.0.0"),
        port=int(os.environ.get("PSYMED_PORT", "8000")),
    )
