"""AI message."""

import json
import logging
import operator
from collections.abc import Sequence
from typing import Any, Literal, cast, overload

from pydantic import model_validator
from typing_extensions import NotRequired, Self, TypedDict, override

from langchain_core.messages import content as types
from langchain_core.messages.base import (
    BaseMessage,
    BaseMessageChunk,
    _extract_reasoning_from_additional_kwargs,
    merge_content,
)
from langchain_core.messages.content import InvalidToolCall
from langchain_core.messages.tool import (
    ToolCall,
    ToolCallChunk,
    default_tool_chunk_parser,
    default_tool_parser,
)
from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call
from langchain_core.messages.tool import tool_call as create_tool_call
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
from langchain_core.utils._merge import merge_dicts, merge_lists
from langchain_core.utils.json import parse_partial_json
from langchain_core.utils.usage import _dict_int_op
from langchain_core.utils.utils import LC_AUTO_PREFIX, LC_ID_PREFIX

logger = logging.getLogger(__name__)


class InputTokenDetails(TypedDict, total=False):
    """Breakdown of input token counts.

    Does *not* need to sum to full input token count. Does *not* need to have all keys.

    Example:
        ```python
        {
            "audio": 10,
            "cache_creation": 200,
            "cache_read": 100,
        }
        ```

    May also hold extra provider-specific keys.

    !!! version-added "Added in `langchain-core` 0.3.9"

    """

    audio: int
    """Audio input tokens."""
    cache_creation: int
    """Input tokens that were written to the cache because of a cache miss.

    Since these tokens were not found in the cache, a new cache entry was created
    from them.
    """
    cache_read: int
    """Input tokens that were read from the cache because of a cache hit.

    More precisely, the model state given these tokens was read from the cache.
    """


class OutputTokenDetails(TypedDict, total=False):
    """Breakdown of output token counts.

    Does *not* need to sum to full output token count. Does *not* need to have all keys.

    Example:
        ```python
        {
            "audio": 10,
            "reasoning": 200,
        }
        ```

    May also hold extra provider-specific keys.

    !!! version-added "Added in `langchain-core` 0.3.9"

    """

    audio: int
    """Audio output tokens."""
    reasoning: int
    """Reasoning output tokens.

    Tokens generated by the model in a chain-of-thought process (e.g., by OpenAI's
    o1 models) that are not returned as part of the model output.
    """


class UsageMetadata(TypedDict):
    """Token usage reported for a single message.

    A provider-agnostic representation of token counts, so usage can be compared
    and aggregated across different chat models.

    Example:
        ```python
        {
            "input_tokens": 350,
            "output_tokens": 240,
            "total_tokens": 590,
            "input_token_details": {
                "audio": 10,
                "cache_creation": 200,
                "cache_read": 100,
            },
            "output_token_details": {
                "audio": 10,
                "reasoning": 200,
            },
        }
        ```

    !!! warning "Behavior changed in `langchain-core` 0.3.9"
        Added `input_token_details` and `output_token_details`.

    !!! note "LangSmith SDK"
        The LangSmith SDK defines its own `UsageMetadata` class. It shares these
        fields but adds extra ones that capture cost information for the LangSmith
        platform.
    """

    input_tokens: int
    """Number of input (prompt) tokens, summed over every input token type."""
    output_tokens: int
    """Number of output (completion) tokens, summed over every output token type."""
    total_tokens: int
    """Overall token count, equal to `input_tokens` + `output_tokens`."""
    input_token_details: NotRequired[InputTokenDetails]
    """Optional per-type breakdown of the input tokens.

    The entries are not required to add up to `input_tokens`, and any key may be
    absent.
    """
    output_token_details: NotRequired[OutputTokenDetails]
    """Optional per-type breakdown of the output tokens.

    The entries are not required to add up to `output_tokens`, and any key may be
    absent.
    """


class AIMessage(BaseMessage):
    """Message from an AI.

    An `AIMessage` is returned from a chat model as a response to a prompt.

    This message represents the output of the model and consists of both
    the raw output as returned by the model and standardized fields
    (e.g., tool calls, usage metadata) added by the LangChain framework.
    """

    tool_calls: list[ToolCall] = []
    """If present, tool calls associated with the message."""
    invalid_tool_calls: list[InvalidToolCall] = []
    """If present, tool calls with parsing errors associated with the message."""
    usage_metadata: UsageMetadata | None = None
    """If present, usage metadata for a message, such as token counts.

    This is a standard representation of token usage that is consistent across models.
    """

    type: Literal["ai"] = "ai"
    """The type of the message (used for deserialization)."""

    @overload
    def __init__(
        self,
        content: str | list[str | dict],
        **kwargs: Any,
    ) -> None: ...

    @overload
    def __init__(
        self,
        content: str | list[str | dict] | None = None,
        content_blocks: list[types.ContentBlock] | None = None,
        **kwargs: Any,
    ) -> None: ...

    def __init__(
        self,
        content: str | list[str | dict] | None = None,
        content_blocks: list[types.ContentBlock] | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize an `AIMessage`.

        Specify `content` as positional arg or `content_blocks` for typing.

        When `content_blocks` is given it is used as the message content and
        `content` is ignored. Any `tool_call` blocks it contains are also copied
        into `tool_calls`, unless `tool_calls` is passed explicitly.

        Args:
            content: The content of the message.
            content_blocks: Typed standard content.
            **kwargs: Additional arguments to pass to the parent class.
        """
        if content_blocks is not None:
            # If there are tool calls in content_blocks, but not in tool_calls, add them
            content_tool_calls = [
                block for block in content_blocks if block.get("type") == "tool_call"
            ]
            if content_tool_calls and "tool_calls" not in kwargs:
                kwargs["tool_calls"] = content_tool_calls

            super().__init__(
                content=cast("str | list[str | dict]", content_blocks),
                **kwargs,
            )
        else:
            super().__init__(content=content, **kwargs)

    @property
    def lc_attributes(self) -> dict:
        """Attributes to be serialized.

        Includes all attributes, even if they are derived from other initialization
        arguments.
        """
        return {
            "tool_calls": self.tool_calls,
            "invalid_tool_calls": self.invalid_tool_calls,
        }

    @property
    def content_blocks(self) -> list[types.ContentBlock]:
        """Return standard, typed `ContentBlock` dicts from the message.

        If the message has a known model provider, use the provider-specific translator
        first before falling back to best-effort parsing. For details, see the property
        on `BaseMessage`.
        """
        # v1-formatted content is already a list of standard content blocks.
        if self.response_metadata.get("output_version") == "v1":
            return cast("list[types.ContentBlock]", self.content)

        model_provider = self.response_metadata.get("model_provider")
        if model_provider:
            from langchain_core.messages.block_translators import (  # noqa: PLC0415
                get_translator,
            )

            translator = get_translator(model_provider)
            if translator:
                try:
                    return translator["translate_content"](self)
                except NotImplementedError:
                    # The translator declined; fall back to best-effort parsing.
                    pass

        # Otherwise, use best-effort parsing
        blocks = super().content_blocks

        if self.tool_calls:
            # Add from tool_calls if missing from content. Only tool calls with a
            # truthy id that does not already appear in content are appended.
            content_tool_call_ids = {
                block.get("id")
                for block in self.content
                if isinstance(block, dict) and block.get("type") == "tool_call"
            }
            for tool_call in self.tool_calls:
                if (id_ := tool_call.get("id")) and id_ not in content_tool_call_ids:
                    tool_call_block: types.ToolCall = {
                        "type": "tool_call",
                        "id": id_,
                        "name": tool_call["name"],
                        "args": tool_call["args"],
                    }
                    if "index" in tool_call:
                        tool_call_block["index"] = tool_call["index"]  # type: ignore[typeddict-item]
                    if "extras" in tool_call:
                        tool_call_block["extras"] = tool_call["extras"]  # type: ignore[typeddict-item]
                    blocks.append(tool_call_block)

        # Best-effort reasoning extraction from additional_kwargs
        # Only add reasoning if not already present
        # Insert before all other blocks to keep reasoning at the start
        has_reasoning = any(block.get("type") == "reasoning" for block in blocks)
        if not has_reasoning and (
            reasoning_block := _extract_reasoning_from_additional_kwargs(self)
        ):
            blocks.insert(0, reasoning_block)

        return blocks

    # TODO: remove this logic if possible, reducing breaking nature of changes
    @model_validator(mode="before")
    @classmethod
    def _backwards_compat_tool_calls(cls, values: dict) -> Any:
        """Normalize tool-call fields before field validation.

        If none of `tool_calls`, `invalid_tool_calls` or `tool_call_chunks` were
        provided, try to parse raw tool calls from `additional_kwargs["tool_calls"]`.
        Chunk classes get `tool_call_chunks`; other classes get `tool_calls` and
        `invalid_tool_calls`. Parsing failures are logged at debug level and
        otherwise ignored. Finally, re-create every tool-call-like dict through its
        factory so that the `"type"` key is set consistently.
        """
        check_additional_kwargs = not any(
            values.get(k)
            for k in ("tool_calls", "invalid_tool_calls", "tool_call_chunks")
        )
        # `or {}` tolerates an explicit `additional_kwargs=None`, so that pydantic
        # reports a proper validation error instead of this raising AttributeError.
        if check_additional_kwargs and (
            raw_tool_calls := (values.get("additional_kwargs") or {}).get("tool_calls")
        ):
            try:
                if issubclass(cls, AIMessageChunk):
                    values["tool_call_chunks"] = default_tool_chunk_parser(
                        raw_tool_calls
                    )
                else:
                    parsed_tool_calls, parsed_invalid_tool_calls = default_tool_parser(
                        raw_tool_calls
                    )
                    values["tool_calls"] = parsed_tool_calls
                    values["invalid_tool_calls"] = parsed_invalid_tool_calls
            except Exception:
                # Best-effort: legacy payloads that fail to parse are ignored.
                logger.debug("Failed to parse tool calls", exc_info=True)

        # Ensure "type" is properly set on all tool call-like dicts.
        # Note: "extras" is dropped from tool_calls here.
        if tool_calls := values.get("tool_calls"):
            values["tool_calls"] = [
                create_tool_call(
                    **{k: v for k, v in tc.items() if k not in ("type", "extras")}
                )
                for tc in tool_calls
            ]
        if invalid_tool_calls := values.get("invalid_tool_calls"):
            values["invalid_tool_calls"] = [
                create_invalid_tool_call(**{k: v for k, v in tc.items() if k != "type"})
                for tc in invalid_tool_calls
            ]

        if tool_call_chunks := values.get("tool_call_chunks"):
            values["tool_call_chunks"] = [
                create_tool_call_chunk(**{k: v for k, v in tc.items() if k != "type"})
                for tc in tool_call_chunks
            ]

        return values

    @override
    def pretty_repr(self, html: bool = False) -> str:
        """Return a pretty representation of the message for display.

        Args:
            html: Whether to return an HTML-formatted string.

        Returns:
            A pretty representation of the message.

        """
        base = super().pretty_repr(html=html)
        lines = []

        def _format_tool_args(tc: ToolCall | InvalidToolCall) -> list[str]:
            """Format one (valid or invalid) tool call as indented text lines."""
            lines = [
                f"  {tc.get('name', 'Tool')} ({tc.get('id')})",
                f" Call ID: {tc.get('id')}",
            ]
            if tc.get("error"):
                lines.append(f"  Error: {tc.get('error')}")
            lines.append("  Args:")
            # Invalid tool calls carry raw string args; valid ones carry a dict.
            args = tc.get("args")
            if isinstance(args, str):
                lines.append(f"    {args}")
            elif isinstance(args, dict):
                for arg, value in args.items():
                    lines.append(f"    {arg}: {value}")
            return lines

        if self.tool_calls:
            lines.append("Tool Calls:")
            for tc in self.tool_calls:
                lines.extend(_format_tool_args(tc))
        if self.invalid_tool_calls:
            lines.append("Invalid Tool Calls:")
            for itc in self.invalid_tool_calls:
                lines.extend(_format_tool_args(itc))
        return (base.strip() + "\n" + "\n".join(lines)).strip()


class AIMessageChunk(AIMessage, BaseMessageChunk):
    """Message chunk from an AI (yielded when streaming).

    Chunks can be combined with `+`, which delegates to `add_ai_message_chunks`.
    Streamed tool calls live in `tool_call_chunks`. The `init_tool_calls`
    validator parses their string `args` into `tool_calls` or
    `invalid_tool_calls`.
    """

    # Ignoring mypy re-assignment here since we're overriding the value
    # to make sure that the chunk variant can be discriminated from the
    # non-chunk variant.
    type: Literal["AIMessageChunk"] = "AIMessageChunk"  # type: ignore[assignment]
    """The type of the message (used for deserialization)."""

    tool_call_chunks: list[ToolCallChunk] = []
    """If provided, tool call chunks associated with the message."""

    chunk_position: Literal["last"] | None = None
    """Optional span represented by an aggregated `AIMessageChunk`.

    If a chunk with `chunk_position="last"` is aggregated into a stream,
    `tool_call_chunks` in message content will be parsed into `tool_calls`.
    """

    @property
    def lc_attributes(self) -> dict:
        """Attributes to be serialized, even if they are derived from other initialization args."""  # noqa: E501
        # tool_calls / invalid_tool_calls are derived from tool_call_chunks by
        # init_tool_calls, but are still serialized explicitly.
        return {
            "tool_calls": self.tool_calls,
            "invalid_tool_calls": self.invalid_tool_calls,
        }

    @property
    def content_blocks(self) -> list[types.ContentBlock]:
        """Return standard, typed `ContentBlock` dicts from the message.

        v1-formatted content is returned as-is. Otherwise, a provider translator's
        `translate_content_chunk` is tried first. If there is no translator, or it
        raises `NotImplementedError`, best-effort parsing is used instead. For a
        non-final chunk with no content, tool calls are reported as
        `tool_call_chunk` blocks built from `tool_call_chunks`.
        """
        # v1-formatted content is already a list of standard content blocks.
        if self.response_metadata.get("output_version") == "v1":
            return cast("list[types.ContentBlock]", self.content)

        model_provider = self.response_metadata.get("model_provider")
        if model_provider:
            from langchain_core.messages.block_translators import (  # noqa: PLC0415
                get_translator,
            )

            translator = get_translator(model_provider)
            if translator:
                try:
                    return translator["translate_content_chunk"](self)
                except NotImplementedError:
                    # The translator declined; fall back to best-effort parsing.
                    pass

        # Otherwise, use best-effort parsing
        blocks = super().content_blocks

        # For a content-less, non-final chunk, replace the parsed tool_call /
        # invalid_tool_call blocks added by the parent with raw tool_call_chunk
        # blocks built from self.tool_call_chunks.
        if (
            self.tool_call_chunks
            and not self.content
            and self.chunk_position != "last"  # keep tool_calls if aggregated
        ):
            blocks = [
                block
                for block in blocks
                if block["type"] not in ("tool_call", "invalid_tool_call")
            ]
            for tool_call_chunk in self.tool_call_chunks:
                tc: types.ToolCallChunk = {
                    "type": "tool_call_chunk",
                    "id": tool_call_chunk.get("id"),
                    "name": tool_call_chunk.get("name"),
                    "args": tool_call_chunk.get("args"),
                }
                if (idx := tool_call_chunk.get("index")) is not None:
                    tc["index"] = idx
                blocks.append(tc)

        # Best-effort reasoning extraction from additional_kwargs
        # Only add reasoning if not already present
        # Insert before all other blocks to keep reasoning at the start
        has_reasoning = any(block.get("type") == "reasoning" for block in blocks)
        if not has_reasoning and (
            reasoning_block := _extract_reasoning_from_additional_kwargs(self)
        ):
            blocks.insert(0, reasoning_block)

        return blocks

    @model_validator(mode="after")
    def init_tool_calls(self) -> Self:
        """Keep `tool_calls`, `invalid_tool_calls` and `tool_call_chunks` in sync.

        If no `tool_call_chunks` were given, they are synthesized from
        `tool_calls` (with JSON-encoded args) and `invalid_tool_calls`. Otherwise,
        `tool_calls` and `invalid_tool_calls` are rebuilt from the chunks. Chunk
        args that fail to parse, or that parse to something other than a dict,
        become invalid tool calls rather than raising.

        On the final aggregated chunk (`chunk_position="last"`) of v1 output,
        `tool_call_chunk` content blocks are replaced in place by the matching
        parsed `tool_call` blocks (matched by id), keeping any `extras`.

        Returns:
            This message, with tool calls initialized.
        """
        if not self.tool_call_chunks:
            # No chunks given: synthesize them from parsed / invalid tool calls.
            if self.tool_calls:
                self.tool_call_chunks = [
                    create_tool_call_chunk(
                        name=tc["name"],
                        args=json.dumps(tc["args"]),
                        id=tc["id"],
                        index=None,
                    )
                    for tc in self.tool_calls
                ]
            if self.invalid_tool_calls:
                tool_call_chunks = self.tool_call_chunks
                tool_call_chunks.extend(
                    [
                        create_tool_call_chunk(
                            name=tc["name"], args=tc["args"], id=tc["id"], index=None
                        )
                        for tc in self.invalid_tool_calls
                    ]
                )
                self.tool_call_chunks = tool_call_chunks

            return self
        # Chunks given: rebuild tool_calls / invalid_tool_calls from them.
        tool_calls = []
        invalid_tool_calls = []

        def add_chunk_to_invalid_tool_calls(chunk: ToolCallChunk) -> None:
            invalid_tool_calls.append(
                create_invalid_tool_call(
                    name=chunk["name"],
                    args=chunk["args"],
                    id=chunk["id"],
                    error=None,
                )
            )

        for chunk in self.tool_call_chunks:
            try:
                # Empty/None args are treated as an empty argument dict.
                args_ = parse_partial_json(chunk["args"]) if chunk["args"] else {}
                if isinstance(args_, dict):
                    tool_calls.append(
                        create_tool_call(
                            name=chunk["name"] or "",
                            args=args_,
                            id=chunk["id"],
                        )
                    )
                else:
                    add_chunk_to_invalid_tool_calls(chunk)
            except Exception:
                add_chunk_to_invalid_tool_calls(chunk)
        self.tool_calls = tool_calls
        self.invalid_tool_calls = invalid_tool_calls

        # Final aggregated v1 chunk: swap tool_call_chunk content blocks for the
        # parsed tool_call blocks with the same id.
        if (
            self.chunk_position == "last"
            and self.tool_call_chunks
            and self.response_metadata.get("output_version") == "v1"
            and isinstance(self.content, list)
        ):
            id_to_tc: dict[str, types.ToolCall] = {
                cast("str", tc.get("id")): {
                    "type": "tool_call",
                    "name": tc["name"],
                    "args": tc["args"],
                    "id": tc.get("id"),
                }
                for tc in self.tool_calls
                if "id" in tc
            }
            for idx, block in enumerate(self.content):
                if (
                    isinstance(block, dict)
                    and block.get("type") == "tool_call_chunk"
                    and (call_id := block.get("id"))
                    and call_id in id_to_tc
                ):
                    self.content[idx] = cast("dict[str, Any]", id_to_tc[call_id])
                    if "extras" in block:
                        # mypy does not account for instance check for dict above
                        self.content[idx]["extras"] = block["extras"]  # type: ignore[index]

        return self

    @model_validator(mode="after")
    def init_server_tool_calls(self) -> Self:
        """Parse `server_tool_call_chunks`.

        Only applies to the final aggregated chunk (`chunk_position="last"`) of v1
        output. A `server_tool_call` or `server_tool_call_chunk` content block whose
        `args` is a JSON string decoding to a dict is updated in place. Its `args`
        becomes the decoded dict and its `type` becomes `"server_tool_call"`.
        Blocks whose args are not valid JSON are left unchanged.
        """
        if (
            self.chunk_position == "last"
            and self.response_metadata.get("output_version") == "v1"
            and isinstance(self.content, list)
        ):
            for idx, block in enumerate(self.content):
                if (
                    isinstance(block, dict)
                    and block.get("type")
                    in ("server_tool_call", "server_tool_call_chunk")
                    and (args_str := block.get("args"))
                    and isinstance(args_str, str)
                ):
                    try:
                        args = json.loads(args_str)
                        if isinstance(args, dict):
                            self.content[idx]["type"] = "server_tool_call"  # type: ignore[index]
                            self.content[idx]["args"] = args  # type: ignore[index]
                    except json.JSONDecodeError:
                        # Incomplete/invalid JSON: leave the block as streamed.
                        pass
        return self

    @overload  # type: ignore[override]  # summing BaseMessages gives ChatPromptTemplate
    def __add__(self, other: "AIMessageChunk") -> "AIMessageChunk": ...

    @overload
    def __add__(self, other: Sequence["AIMessageChunk"]) -> "AIMessageChunk": ...

    @overload
    def __add__(self, other: Any) -> BaseMessageChunk: ...

    @override
    def __add__(self, other: Any) -> BaseMessageChunk:
        """Add another chunk, or a list/tuple of chunks, to this chunk.

        `AIMessageChunk` operands are merged with `add_ai_message_chunks`. Any
        other operand is delegated to the parent implementation.
        """
        if isinstance(other, AIMessageChunk):
            return add_ai_message_chunks(self, other)
        if isinstance(other, (list, tuple)) and all(
            isinstance(o, AIMessageChunk) for o in other
        ):
            return add_ai_message_chunks(self, *other)
        return super().__add__(other)


def add_ai_message_chunks(
    left: AIMessageChunk, *others: AIMessageChunk
) -> AIMessageChunk:
    """Add multiple `AIMessageChunk`s together.

    Content, `additional_kwargs` and `response_metadata` are merged. Tool call
    chunks are merged and re-created through the tool call chunk factory. Usage
    metadata is summed with `add_usage` when any chunk carries it.

    The resulting ID is the first provider-assigned ID (one with neither the
    `LC_ID_PREFIX` nor the `LC_AUTO_PREFIX` prefix). If there is none, the first
    `LC_ID_PREFIX` ID is used, and failing that any other ID. The result is
    marked `chunk_position="last"` if any input chunk is.

    Args:
        left: The first `AIMessageChunk`.
        *others: Other `AIMessageChunk`s to add.

    Returns:
        The resulting `AIMessageChunk`, an instance of the same class as `left`.

    """
    all_chunks = [left, *others]

    merged_content = merge_content(left.content, *(o.content for o in others))
    merged_kwargs = merge_dicts(
        left.additional_kwargs, *(o.additional_kwargs for o in others)
    )
    merged_metadata = merge_dicts(
        left.response_metadata, *(o.response_metadata for o in others)
    )

    # Merge tool call chunks; `merge_lists` may return None or an empty list.
    merged_raw_chunks = merge_lists(
        left.tool_call_chunks, *(o.tool_call_chunks for o in others)
    )
    merged_tool_chunks = [
        create_tool_call_chunk(
            name=raw.get("name"),
            args=raw.get("args"),
            index=raw.get("index"),
            id=raw.get("id"),
        )
        for raw in (merged_raw_chunks or [])
    ]

    # Token usage: only computed when at least one chunk reports it.
    summed_usage: UsageMetadata | None = None
    if left.usage_metadata or any(o.usage_metadata is not None for o in others):
        summed_usage = left.usage_metadata
        for other_chunk in others:
            summed_usage = add_usage(summed_usage, other_chunk.usage_metadata)

    # ID precedence: provider-assigned > LC_ID_PREFIX > any remaining ID.
    known_ids = [c.id for c in all_chunks if c.id]
    merged_id = next(
        (
            candidate
            for candidate in known_ids
            if not candidate.startswith((LC_ID_PREFIX, LC_AUTO_PREFIX))
        ),
        None,
    )
    if merged_id is None:
        merged_id = next(
            (candidate for candidate in known_ids if candidate.startswith(LC_ID_PREFIX)),
            None,
        )
    if merged_id is None and known_ids:
        merged_id = known_ids[0]

    merged_position: Literal["last"] | None = None
    if any(c.chunk_position == "last" for c in all_chunks):
        merged_position = "last"

    return type(left)(
        content=merged_content,
        additional_kwargs=merged_kwargs,
        tool_call_chunks=merged_tool_chunks,
        response_metadata=merged_metadata,
        usage_metadata=summed_usage,
        id=merged_id,
        chunk_position=merged_position,
    )


def add_usage(left: UsageMetadata | None, right: UsageMetadata | None) -> UsageMetadata:
    """Recursively add two `UsageMetadata` objects.

    Counts are added key by key, including nested token-detail dicts. If only one
    operand is present, it is returned as-is. If neither is present, an all-zero
    `UsageMetadata` is returned.

    Example:
        ```python
        from langchain_core.messages.ai import (
            InputTokenDetails,
            OutputTokenDetails,
            UsageMetadata,
            add_usage,
        )

        left = UsageMetadata(
            input_tokens=5,
            output_tokens=0,
            total_tokens=5,
            input_token_details=InputTokenDetails(cache_read=3),
        )
        right = UsageMetadata(
            input_tokens=0,
            output_tokens=10,
            total_tokens=10,
            output_token_details=OutputTokenDetails(reasoning=4),
        )

        add_usage(left, right)
        ```

        results in

        ```python
        UsageMetadata(
            input_tokens=5,
            output_tokens=10,
            total_tokens=15,
            input_token_details=InputTokenDetails(cache_read=3),
            output_token_details=OutputTokenDetails(reasoning=4),
        )
        ```

    Args:
        left: The first `UsageMetadata` object.
        right: The second `UsageMetadata` object.

    Returns:
        The sum of the two `UsageMetadata` objects.

    """
    if left and right:
        summed = _dict_int_op(cast("dict", left), cast("dict", right), operator.add)
        return UsageMetadata(**cast("UsageMetadata", summed))
    if left or right:
        # Exactly one side is present: nothing to add.
        return cast("UsageMetadata", left or right)
    return UsageMetadata(input_tokens=0, output_tokens=0, total_tokens=0)


def subtract_usage(
    left: UsageMetadata | None, right: UsageMetadata | None
) -> UsageMetadata:
    """Recursively subtract two `UsageMetadata` objects.

    Token counts cannot be negative so the actual operation is `max(left - right, 0)`.

    A missing (`None`) operand is treated as zero usage:
    `subtract_usage(usage, None)` returns `usage` unchanged, while
    `subtract_usage(None, usage)` yields zero counts.

    Example:
        ```python
        from langchain_core.messages.ai import (
            InputTokenDetails,
            OutputTokenDetails,
            UsageMetadata,
            subtract_usage,
        )

        left = UsageMetadata(
            input_tokens=5,
            output_tokens=10,
            total_tokens=15,
            input_token_details=InputTokenDetails(cache_read=4),
        )
        right = UsageMetadata(
            input_tokens=3,
            output_tokens=8,
            total_tokens=11,
            output_token_details=OutputTokenDetails(reasoning=4),
        )

        subtract_usage(left, right)
        ```

        results in

        ```python
        UsageMetadata(
            input_tokens=2,
            output_tokens=2,
            total_tokens=4,
            input_token_details=InputTokenDetails(cache_read=4),
            output_token_details=OutputTokenDetails(reasoning=0),
        )
        ```

    Args:
        left: The first `UsageMetadata` object.
        right: The second `UsageMetadata` object.

    Returns:
        The resulting `UsageMetadata` after subtraction.

    """
    if not (left or right):
        return UsageMetadata(input_tokens=0, output_tokens=0, total_tokens=0)
    if not right:
        # Subtracting nothing leaves `left` unchanged.
        return cast("UsageMetadata", left)
    if not left:
        # A missing `left` is zero usage. Previously `right` was returned here,
        # which contradicts `max(left - right, 0)`. Subtracting from zeros instead
        # yields 0 for every non-negative count, including detail keys that only
        # appear in `right`.
        left = UsageMetadata(input_tokens=0, output_tokens=0, total_tokens=0)

    return UsageMetadata(
        **cast(
            "UsageMetadata",
            _dict_int_op(
                cast("dict", left),
                cast("dict", right),
                (lambda le, ri: max(le - ri, 0)),
            ),
        )
    )
