from __future__ import annotations

import base64
import functools
import json
import logging
from collections.abc import Mapping
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Optional,
    TypeVar,
    Union,
)

from typing_extensions import TypedDict

from langsmith import client as ls_client
from langsmith import run_helpers
from langsmith._internal._beta_decorator import warn_beta
from langsmith.schemas import InputTokenDetails, OutputTokenDetails, UsageMetadata

if TYPE_CHECKING:
    from google import genai  # type: ignore[import-untyped, attr-defined]

C = TypeVar("C", bound=Union["genai.Client", Any])
logger = logging.getLogger(__name__)


def _strip_none(d: dict) -> dict:
    """Remove `None` values from dictionary."""
    return {k: v for k, v in d.items() if v is not None}


def _convert_config_for_tracing(kwargs: dict) -> None:
    """Convert `GenerateContentConfig` to `dict` for LangSmith compatibility."""
    if "config" in kwargs and not isinstance(kwargs["config"], dict):
        kwargs["config"] = vars(kwargs["config"])


def _process_gemini_inputs(inputs: dict) -> dict:
    r"""Process Gemini inputs to normalize them for LangSmith tracing.

    Example:
        ```txt
        {"contents": "Hello", "model": "gemini-pro"}
        → {"messages": [{"role": "user", "content": "Hello"}], "model": "gemini-pro"}
        {"contents": [{"role": "user", "parts": [{"text": "What is AI?"}]}], "model": "gemini-pro"}
        → {"messages": [{"role": "user", "content": "What is AI?"}], "model": "gemini-pro"}
        ```
    """  # noqa: E501
    # If contents is not present or not in list format, return as-is
    contents = inputs.get("contents")
    if not contents:
        return inputs

    # Handle string input (simple case)
    if isinstance(contents, str):
        return {
            "messages": [{"role": "user", "content": contents}],
            "model": inputs.get("model"),
            **({k: v for k, v in inputs.items() if k not in ("contents", "model")}),
        }

    # Handle list of content objects (multimodal case)
    if isinstance(contents, list):
        # Check if it's a simple list of strings
        if all(isinstance(item, str) for item in contents):
            # Each string becomes a separate user message (matches Gemini's behavior)
            return {
                "messages": [{"role": "user", "content": item} for item in contents],
                "model": inputs.get("model"),
                **({k: v for k, v in inputs.items() if k not in ("contents", "model")}),
            }
        # Handle complex multimodal case
        messages = []
        for content in contents:
            if isinstance(content, dict):
                role = content.get("role", "user")
                parts = content.get("parts", [])

                # Extract text and other parts
                text_parts = []
                content_parts = []

                for part in parts:
                    if isinstance(part, dict):
                        # Handle text parts
                        if "text" in part and part["text"]:
                            text_parts.append(part["text"])
                            content_parts.append({"type": "text", "text": part["text"]})
                        # Handle inline data (images)
                        elif "inline_data" in part:
                            inline_data = part["inline_data"]
                            mime_type = inline_data.get("mime_type", "image/jpeg")
                            data = inline_data.get("data", b"")

                            # Convert bytes to base64 string if needed
                            if isinstance(data, bytes):
                                data_b64 = base64.b64encode(data).decode("utf-8")
                            else:
                                data_b64 = data  # Already a string

                            content_parts.append(
                                {
                                    "type": "image_url",
                                    "image_url": {
                                        "url": f"data:{mime_type};base64,{data_b64}",
                                        "detail": "high",
                                    },
                                }
                            )
                        # Handle function responses
                        elif "functionResponse" in part:
                            function_response = part["functionResponse"]
                            content_parts.append(
                                {
                                    "type": "function_response",
                                    "function_response": {
                                        "name": function_response.get("name"),
                                        "response": function_response.get(
                                            "response", {}
                                        ),
                                    },
                                }
                            )
                        # Handle function calls (for conversation history)
                        elif "function_call" in part or "functionCall" in part:
                            function_call = part.get("function_call") or part.get(
                                "functionCall"
                            )

                            if function_call is not None:
                                # Normalize to dict (FunctionCall is a Pydantic model)
                                if not isinstance(function_call, dict):
                                    function_call = function_call.to_dict()

                                content_parts.append(
                                    {
                                        "type": "function_call",
                                        "function_call": {
                                            "id": function_call.get("id"),
                                            "name": function_call.get("name"),
                                            "arguments": function_call.get("args", {}),
                                        },
                                    }
                                )
                    elif isinstance(part, str):
                        # Handle simple string parts
                        text_parts.append(part)
                        content_parts.append({"type": "text", "text": part})

                # If only text parts, use simple string format
                if content_parts and all(
                    p.get("type") == "text" for p in content_parts
                ):
                    message_content: Union[str, list[dict[str, Any]]] = "\n".join(
                        text_parts
                    )
                else:
                    message_content = content_parts if content_parts else ""

                messages.append({"role": role, "content": message_content})
        return {
            "messages": messages,
            "model": inputs.get("model"),
            **({k: v for k, v in inputs.items() if k not in ("contents", "model")}),
        }

    # Fallback: return original inputs
    return inputs


def _infer_invocation_params(kwargs: dict) -> dict:
    """Extract invocation parameters for tracing."""
    stripped = _strip_none(kwargs)
    config = stripped.get("config", {})

    # Handle both dict config and GenerateContentConfig object
    if hasattr(config, "temperature"):
        temperature = config.temperature
        max_tokens = getattr(config, "max_output_tokens", None)
        stop = getattr(config, "stop_sequences", None)
    else:
        temperature = config.get("temperature")
        max_tokens = config.get("max_output_tokens")
        stop = config.get("stop_sequences")

    return {
        "ls_provider": "google",
        "ls_model_type": "chat",
        "ls_model_name": stripped.get("model"),
        "ls_temperature": temperature,
        "ls_max_tokens": max_tokens,
        "ls_stop": stop,
    }


def _create_usage_metadata(gemini_usage_metadata: dict) -> UsageMetadata:
    """Convert Gemini usage metadata to LangSmith format."""
    prompt_token_count = gemini_usage_metadata.get("prompt_token_count") or 0
    candidates_token_count = gemini_usage_metadata.get("candidates_token_count") or 0
    cached_content_token_count = (
        gemini_usage_metadata.get("cached_content_token_count") or 0
    )
    thoughts_token_count = gemini_usage_metadata.get("thoughts_token_count") or 0
    total_token_count = (
        gemini_usage_metadata.get("total_token_count")
        or prompt_token_count + candidates_token_count
    )

    input_token_details: dict = {}
    if cached_content_token_count:
        input_token_details["cache_read"] = cached_content_token_count

    output_token_details: dict = {}
    if thoughts_token_count:
        output_token_details["reasoning"] = thoughts_token_count

    return UsageMetadata(
        input_tokens=prompt_token_count,
        output_tokens=candidates_token_count,
        total_tokens=total_token_count,
        input_token_details=InputTokenDetails(
            **{k: v for k, v in input_token_details.items() if v is not None}
        ),
        output_token_details=OutputTokenDetails(
            **{k: v for k, v in output_token_details.items() if v is not None}
        ),
    )


def _process_generate_content_response(response: Any) -> dict:
    """Process Gemini response for tracing."""
    try:
        # Convert response to dictionary
        if hasattr(response, "to_dict"):
            rdict = response.to_dict()
        elif hasattr(response, "model_dump"):
            rdict = response.model_dump()
        else:
            rdict = {"text": getattr(response, "text", str(response))}

        # Extract content from candidates if available
        content_result = ""
        content_parts = []
        finish_reason: Optional[str] = None
        if "candidates" in rdict and rdict["candidates"]:
            candidate = rdict["candidates"][0]
            if "content" in candidate:
                content = candidate["content"]
                if "parts" in content and content["parts"]:
                    for part in content["parts"]:
                        # Handle text parts
                        if "text" in part and part["text"]:
                            content_result += part["text"]
                            content_parts.append({"type": "text", "text": part["text"]})
                        # Handle inline data (images) in response
                        elif "inline_data" in part and part["inline_data"] is not None:
                            inline_data = part["inline_data"]
                            mime_type = inline_data.get("mime_type", "image/jpeg")
                            data = inline_data.get("data", b"")

                            # Convert bytes to base64 string if needed
                            if isinstance(data, bytes):
                                data_b64 = base64.b64encode(data).decode("utf-8")
                            else:
                                data_b64 = data  # Already a string

                            content_parts.append(
                                {
                                    "type": "image_url",
                                    "image_url": {
                                        "url": f"data:{mime_type};base64,{data_b64}",
                                        "detail": "high",
                                    },
                                }
                            )
                        # Handle function calls in response
                        elif "function_call" in part or "functionCall" in part:
                            function_call = part.get("function_call") or part.get(
                                "functionCall"
                            )

                            if function_call is not None:
                                # Normalize to dict (FunctionCall is a Pydantic model)
                                if not isinstance(function_call, dict):
                                    function_call = function_call.to_dict()

                                content_parts.append(
                                    {
                                        "type": "function_call",
                                        "function_call": {
                                            "id": function_call.get("id"),
                                            "name": function_call.get("name"),
                                            "arguments": function_call.get("args", {}),
                                        },
                                    }
                                )
                if "finish_reason" in candidate and candidate["finish_reason"]:
                    finish_reason = candidate["finish_reason"]
        elif "text" in rdict:
            content_result = rdict["text"]
            content_parts.append({"type": "text", "text": content_result})

        # Build chat-like response format - use OpenAI-compatible format for tool calls
        tool_calls = [p for p in content_parts if p.get("type") == "function_call"]
        if tool_calls:
            # OpenAI-compatible format for LangSmith UI
            result = {
                "content": content_result or None,
                "role": "assistant",
                "finish_reason": finish_reason,
                "tool_calls": [
                    {
                        "id": tc["function_call"].get("id") or f"call_{i}",
                        "type": "function",
                        "index": i,
                        "function": {
                            "name": tc["function_call"]["name"],
                            "arguments": json.dumps(tc["function_call"]["arguments"]),
                        },
                    }
                    for i, tc in enumerate(tool_calls)
                ],
            }
        elif len(content_parts) > 1 or (
            content_parts and content_parts[0]["type"] != "text"
        ):
            # Use structured format for mixed non-tool content
            result = {
                "content": content_parts,
                "role": "assistant",
                "finish_reason": finish_reason,
            }
        else:
            # Use simple string format for text-only responses
            result = {
                "content": content_result,
                "role": "assistant",
                "finish_reason": finish_reason,
            }

        # Extract and convert usage metadata
        usage_metadata = rdict.get("usage_metadata")
        usage_dict: UsageMetadata = UsageMetadata(
            input_tokens=0, output_tokens=0, total_tokens=0
        )
        if usage_metadata:
            usage_dict = _create_usage_metadata(usage_metadata)
            # Add usage_metadata to both run.extra AND outputs
            current_run = run_helpers.get_current_run_tree()
            if current_run:
                try:
                    meta = current_run.extra.setdefault("metadata", {}).setdefault(
                        "usage_metadata", {}
                    )
                    meta.update(usage_dict)
                    current_run.patch()
                except Exception as e:
                    logger.warning(f"Failed to update usage metadata: {e}")

        # Return in a format that avoids stringification by LangSmith
        if result.get("tool_calls"):
            # For responses with tool calls, return structured format
            return {
                "content": result["content"],
                "role": "assistant",
                "finish_reason": finish_reason,
                "tool_calls": result["tool_calls"],
                "usage_metadata": usage_dict,
            }
        else:
            # For simple text responses, return minimal structure with usage metadata
            if isinstance(result["content"], str):
                return {
                    "content": result["content"],
                    "role": "assistant",
                    "finish_reason": finish_reason,
                    "usage_metadata": usage_dict,
                }
            else:
                # For multimodal content, return structured format with usage metadata
                return {
                    "content": result["content"],
                    "role": "assistant",
                    "finish_reason": finish_reason,
                    "usage_metadata": usage_dict,
                }
    except Exception as e:
        logger.debug(f"Error processing Gemini response: {e}")
        return {"output": response}


def _reduce_generate_content_chunks(all_chunks: list) -> dict:
    """Reduce streaming chunks into a single response."""
    if not all_chunks:
        return {
            "content": "",
            "usage_metadata": UsageMetadata(
                input_tokens=0, output_tokens=0, total_tokens=0
            ),
        }

    # Accumulate text from all chunks
    full_text = ""
    last_chunk = None

    for chunk in all_chunks:
        try:
            if hasattr(chunk, "text") and chunk.text:
                full_text += chunk.text
            last_chunk = chunk
        except Exception as e:
            logger.debug(f"Error processing chunk: {e}")

    # Extract usage metadata from the last chunk
    usage_metadata: UsageMetadata = UsageMetadata(
        input_tokens=0, output_tokens=0, total_tokens=0
    )
    if last_chunk:
        try:
            if hasattr(last_chunk, "usage_metadata") and last_chunk.usage_metadata:
                if hasattr(last_chunk.usage_metadata, "to_dict"):
                    usage_dict = last_chunk.usage_metadata.to_dict()
                elif hasattr(last_chunk.usage_metadata, "model_dump"):
                    usage_dict = last_chunk.usage_metadata.model_dump()
                else:
                    usage_dict = {
                        "prompt_token_count": getattr(
                            last_chunk.usage_metadata, "prompt_token_count", 0
                        ),
                        "candidates_token_count": getattr(
                            last_chunk.usage_metadata, "candidates_token_count", 0
                        ),
                        "cached_content_token_count": getattr(
                            last_chunk.usage_metadata, "cached_content_token_count", 0
                        ),
                        "thoughts_token_count": getattr(
                            last_chunk.usage_metadata, "thoughts_token_count", 0
                        ),
                        "total_token_count": getattr(
                            last_chunk.usage_metadata, "total_token_count", 0
                        ),
                    }
                # Add usage_metadata to both run.extra AND outputs
                usage_metadata = _create_usage_metadata(usage_dict)
                current_run = run_helpers.get_current_run_tree()
                if current_run:
                    try:
                        meta = current_run.extra.setdefault("metadata", {}).setdefault(
                            "usage_metadata", {}
                        )
                        meta.update(usage_metadata)
                        current_run.patch()
                    except Exception as e:
                        logger.warning(f"Failed to update usage metadata: {e}")
        except Exception as e:
            logger.debug(f"Error extracting metadata from last chunk: {e}")

    # Return minimal structure with usage_metadata in outputs
    return {
        "content": full_text,
        "usage_metadata": usage_metadata,
    }


def _get_wrapper(
    original_generate: Callable,
    name: str,
    tracing_extra: Optional[TracingExtra] = None,
    is_streaming: bool = False,
) -> Callable:
    """Create a wrapper for Gemini's `generate_content` methods."""
    textra = tracing_extra or {}

    @functools.wraps(original_generate)
    def generate(*args, **kwargs):
        # Handle config object before tracing setup
        _convert_config_for_tracing(kwargs)

        decorator = run_helpers.traceable(
            name=name,
            run_type="llm",
            reduce_fn=_reduce_generate_content_chunks if is_streaming else None,
            process_inputs=_process_gemini_inputs,
            process_outputs=(
                _process_generate_content_response if not is_streaming else None
            ),
            _invocation_params_fn=_infer_invocation_params,
            **textra,
        )

        return decorator(original_generate)(*args, **kwargs)

    @functools.wraps(original_generate)
    async def agenerate(*args, **kwargs):
        # Handle config object before tracing setup
        _convert_config_for_tracing(kwargs)

        decorator = run_helpers.traceable(
            name=name,
            run_type="llm",
            reduce_fn=_reduce_generate_content_chunks if is_streaming else None,
            process_inputs=_process_gemini_inputs,
            process_outputs=(
                _process_generate_content_response if not is_streaming else None
            ),
            _invocation_params_fn=_infer_invocation_params,
            **textra,
        )

        return await decorator(original_generate)(*args, **kwargs)

    return agenerate if run_helpers.is_async(original_generate) else generate


class TracingExtra(TypedDict, total=False):
    metadata: Optional[Mapping[str, Any]]
    tags: Optional[list[str]]
    client: Optional[ls_client.Client]


@warn_beta
def wrap_gemini(
    client: C,
    *,
    tracing_extra: Optional[TracingExtra] = None,
    chat_name: str = "ChatGoogleGenerativeAI",
) -> C:
    """Patch the Google Gen AI client to make it traceable.

    !!! warning

        **BETA**: This wrapper is in beta.

    Supports:
        - `generate_content` and `generate_content_stream` methods
        - Sync and async clients
        - Streaming and non-streaming responses
        - Tool/function calling with proper UI rendering
        - Multimodal inputs (text + images)
        - Image generation with `inline_data` support
        - Token usage tracking including reasoning tokens

    Args:
        client: The Google Gen AI client to patch.
        tracing_extra: Extra tracing information.
        chat_name: The run name for the chat endpoint.

    Returns:
        The patched client.

    Example:
        ```python
        from google import genai
        from google.genai import types
        from langsmith import wrappers

        # Use Google Gen AI client same as you normally would.
        client = wrappers.wrap_gemini(genai.Client(api_key="your-api-key"))

        # Basic text generation:
        response = client.models.generate_content(
            model="gemini-2.5-flash",
            contents="Why is the sky blue?",
        )
        print(response.text)

        # Streaming:
        for chunk in client.models.generate_content_stream(
            model="gemini-2.5-flash",
            contents="Tell me a story",
        ):
            print(chunk.text, end="")

        # Tool/Function calling:
        schedule_meeting_function = {
            "name": "schedule_meeting",
            "description": "Schedules a meeting with specified attendees.",
            "parameters": {
                "type": "object",
                "properties": {
                    "attendees": {"type": "array", "items": {"type": "string"}},
                    "date": {"type": "string"},
                    "time": {"type": "string"},
                    "topic": {"type": "string"},
                },
                "required": ["attendees", "date", "time", "topic"],
            },
        }

        tools = types.Tool(function_declarations=[schedule_meeting_function])
        config = types.GenerateContentConfig(tools=[tools])

        response = client.models.generate_content(
            model="gemini-2.5-flash",
            contents="Schedule a meeting with Bob and Alice tomorrow at 2 PM.",
            config=config,
        )

        # Image generation:
        response = client.models.generate_content(
            model="gemini-2.5-flash-image",
            contents=["Create a picture of a futuristic city"],
        )

        # Save generated image
        from io import BytesIO
        from PIL import Image

        for part in response.candidates[0].content.parts:
            if part.inline_data is not None:
                image = Image.open(BytesIO(part.inline_data.data))
                image.save("generated_image.png")
        ```

    !!! version-added "Added in `langsmith` 0.4.33"

        Initial beta release of Google Gemini wrapper.

    """
    tracing_extra = tracing_extra or {}

    # Check if already wrapped to prevent double-wrapping
    if (
        hasattr(client, "models")
        and hasattr(client.models, "generate_content")
        and hasattr(client.models.generate_content, "__wrapped__")
    ):
        raise ValueError(
            "This Google Gen AI client has already been wrapped. "
            "Wrapping a client multiple times is not supported."
        )

    # Wrap synchronous methods
    if hasattr(client, "models") and hasattr(client.models, "generate_content"):
        client.models.generate_content = _get_wrapper(  # type: ignore[method-assign]
            client.models.generate_content,
            chat_name,
            tracing_extra=tracing_extra,
            is_streaming=False,
        )

    if hasattr(client, "models") and hasattr(client.models, "generate_content_stream"):
        client.models.generate_content_stream = _get_wrapper(  # type: ignore[method-assign]
            client.models.generate_content_stream,
            chat_name,
            tracing_extra=tracing_extra,
            is_streaming=True,
        )

    # Wrap async methods (aio namespace)
    if (
        hasattr(client, "aio")
        and hasattr(client.aio, "models")
        and hasattr(client.aio.models, "generate_content")
    ):
        client.aio.models.generate_content = _get_wrapper(  # type: ignore[method-assign]
            client.aio.models.generate_content,
            chat_name,
            tracing_extra=tracing_extra,
            is_streaming=False,
        )

    if (
        hasattr(client, "aio")
        and hasattr(client.aio, "models")
        and hasattr(client.aio.models, "generate_content_stream")
    ):
        client.aio.models.generate_content_stream = _get_wrapper(  # type: ignore[method-assign]
            client.aio.models.generate_content_stream,
            chat_name,
            tracing_extra=tracing_extra,
            is_streaming=True,
        )

    return client
