"""Base parser for language model outputs."""

from __future__ import annotations

import contextlib
from abc import ABC, abstractmethod
from typing import (
    TYPE_CHECKING,
    Any,
    Generic,
    TypeVar,
)

from typing_extensions import override

from langchain_core.language_models import LanguageModelOutput
from langchain_core.messages import AnyMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable
from langchain_core.runnables.config import run_in_executor

if TYPE_CHECKING:
    from langchain_core.prompt_values import PromptValue

# Type variable for the structured value produced by the parser classes below.
T = TypeVar("T")
# Alias for any `Runnable` that accepts a `LanguageModelOutput` and yields `T`.
OutputParserLike = Runnable[LanguageModelOutput, T]


class BaseLLMOutputParser(ABC, Generic[T]):
    """Interface for objects that turn model generations into a value of type `T`.

    Subclasses must implement `parse_result`. `aparse_result` ships with a
    default implementation that is built on top of `parse_result`.
    """

    @abstractmethod
    def parse_result(self, result: list[Generation], *, partial: bool = False) -> T:
        """Convert candidate `Generation` objects into the parser's output format.

        Args:
            result: The `Generation` objects to parse. They are assumed to be
                alternative candidate outputs produced for one model input.
            partial: Whether `result` should be treated as a partial output.
                Only meaningful for parsers that support partial parsing.

        Returns:
            Structured output.
        """

    async def aparse_result(
        self, result: list[Generation], *, partial: bool = False
    ) -> T:
        """Async variant of `parse_result`.

        The default implementation hands the synchronous `parse_result` to
        `run_in_executor` and awaits its outcome.

        Args:
            result: The `Generation` objects to parse. They are assumed to be
                alternative candidate outputs produced for one model input.
            partial: Whether `result` should be treated as a partial output.
                Only meaningful for parsers that support partial parsing.

        Returns:
            Structured output.
        """
        sync_parse = self.parse_result
        parsed: T = await run_in_executor(None, sync_parse, result, partial=partial)
        return parsed


class BaseGenerationOutputParser(
    BaseLLMOutputParser, RunnableSerializable[LanguageModelOutput, T]
):
    """Runnable parser that works directly on `Generation` objects.

    `invoke` / `ainvoke` wrap their input in a single-element list of
    `ChatGeneration` (for messages) or `Generation` (for strings) and pass it
    to `parse_result` / `aparse_result`.
    """

    @property
    @override
    def InputType(self) -> Any:
        """Return the input type for the parser: a string or any message."""
        accepted_input = str | AnyMessage
        return accepted_input

    @property
    @override
    def OutputType(self) -> type[T]:
        """Return the output type for the parser (the type variable `T`)."""
        # mypy rejects returning a bare type variable here, but pydantic can
        # still build a schema from it, which is all that is needed.
        return T  # type: ignore[misc]

    @override
    def invoke(
        self,
        input: str | BaseMessage,
        config: RunnableConfig | None = None,
        **kwargs: Any,
    ) -> T:
        """Parse a single string or message as a traced "parser" run.

        Args:
            input: Model output to parse, either raw text or a message.
            config: Optional runnable config, forwarded to `_call_with_config`.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            Whatever `parse_result` returns for the wrapped input.
        """
        # The wrapping happens inside the callable so that it runs as part of
        # the call made by `_call_with_config`.
        if isinstance(input, BaseMessage):
            def _parse(payload: Any) -> T:
                return self.parse_result([ChatGeneration(message=payload)])

        else:
            def _parse(payload: Any) -> T:
                return self.parse_result([Generation(text=payload)])

        return self._call_with_config(_parse, input, config, run_type="parser")

    @override
    async def ainvoke(
        self,
        input: str | BaseMessage,
        config: RunnableConfig | None = None,
        **kwargs: Any | None,
    ) -> T:
        """Async variant of `invoke`, backed by `aparse_result`.

        Args:
            input: Model output to parse, either raw text or a message.
            config: Optional runnable config, forwarded to `_acall_with_config`.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            Whatever `aparse_result` returns for the wrapped input.
        """
        if isinstance(input, BaseMessage):
            async def _aparse(payload: Any) -> T:
                return await self.aparse_result([ChatGeneration(message=payload)])

        else:
            async def _aparse(payload: Any) -> T:
                return await self.aparse_result([Generation(text=payload)])

        return await self._acall_with_config(
            _aparse, input, config, run_type="parser"
        )


class BaseOutputParser(
    BaseLLMOutputParser, RunnableSerializable[LanguageModelOutput, T]
):
    """Runnable parser that turns the text of a model response into structure.

    Subclasses implement `parse`, which receives the response text. The
    `Generation`-based entry points (`parse_result`, `invoke`, ...) are
    provided here and delegate to it.

    Example:
        ```python
        # Implement a simple boolean output parser


        class BooleanOutputParser(BaseOutputParser[bool]):
            true_val: str = "YES"
            false_val: str = "NO"

            def parse(self, text: str) -> bool:
                cleaned_text = text.strip().upper()
                if cleaned_text not in (
                    self.true_val.upper(),
                    self.false_val.upper(),
                ):
                    raise OutputParserException(
                        f"BooleanOutputParser expected output value to either be "
                        f"{self.true_val} or {self.false_val} (case-insensitive). "
                        f"Received {cleaned_text}."
                    )
                return cleaned_text == self.true_val.upper()

            @property
            def _type(self) -> str:
                return "boolean_output_parser"
        ```
    """

    @property
    @override
    def InputType(self) -> Any:
        """Return the input type for the parser: a string or any message."""
        accepted_input = str | AnyMessage
        return accepted_input

    @property
    @override
    def OutputType(self) -> type[T]:
        """Return the output type for the parser.

        The type is taken from the first generic argument recorded in
        `__pydantic_generic_metadata__` on the nearest class in the MRO that
        has a non-empty `args` entry.

        Raises:
            TypeError: If the class doesn't have an inferable `OutputType`.
        """
        for ancestor in self.__class__.mro():
            if not hasattr(ancestor, "__pydantic_generic_metadata__"):
                continue
            generic_info = ancestor.__pydantic_generic_metadata__
            if "args" not in generic_info:
                continue
            type_args = generic_info["args"]
            if len(type_args) > 0:
                return type_args[0]

        msg = (
            f"Runnable {self.__class__.__name__} doesn't have an inferable OutputType. "
            "Override the OutputType property to specify the output type."
        )
        raise TypeError(msg)

    @override
    def invoke(
        self,
        input: str | BaseMessage,
        config: RunnableConfig | None = None,
        **kwargs: Any,
    ) -> T:
        """Parse a single string or message as a traced "parser" run.

        Args:
            input: Model output to parse, either raw text or a message.
            config: Optional runnable config, forwarded to `_call_with_config`.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            Whatever `parse_result` returns for the wrapped input.
        """
        # The wrapping happens inside the callable so that it runs as part of
        # the call made by `_call_with_config`.
        if isinstance(input, BaseMessage):
            def _parse(payload: Any) -> T:
                return self.parse_result([ChatGeneration(message=payload)])

        else:
            def _parse(payload: Any) -> T:
                return self.parse_result([Generation(text=payload)])

        return self._call_with_config(_parse, input, config, run_type="parser")

    @override
    async def ainvoke(
        self,
        input: str | BaseMessage,
        config: RunnableConfig | None = None,
        **kwargs: Any | None,
    ) -> T:
        """Async variant of `invoke`, backed by `aparse_result`.

        Args:
            input: Model output to parse, either raw text or a message.
            config: Optional runnable config, forwarded to `_acall_with_config`.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            Whatever `aparse_result` returns for the wrapped input.
        """
        if isinstance(input, BaseMessage):
            async def _aparse(payload: Any) -> T:
                return await self.aparse_result([ChatGeneration(message=payload)])

        else:
            async def _aparse(payload: Any) -> T:
                return await self.aparse_result([Generation(text=payload)])

        return await self._acall_with_config(
            _aparse, input, config, run_type="parser"
        )

    @override
    def parse_result(self, result: list[Generation], *, partial: bool = False) -> T:
        """Parse candidate `Generation` objects by parsing the first one's text.

        Only the first `Generation` in `result` is used; it is assumed to be
        the highest-likelihood candidate. Its `text` is passed to `parse`.

        Args:
            result: The `Generation` objects to parse. They are assumed to be
                alternative candidate outputs produced for one model input.
            partial: Whether `result` should be treated as a partial output.
                This implementation accepts the flag but does not consult it.

        Returns:
            Structured output.

        Raises:
            IndexError: If `result` is an empty list.
        """
        top_candidate = result[0]
        return self.parse(top_candidate.text)

    @abstractmethod
    def parse(self, text: str) -> T:
        """Turn one string of model output into the parser's structure.

        Args:
            text: String output of a language model.

        Returns:
            Structured output.
        """

    async def aparse_result(
        self, result: list[Generation], *, partial: bool = False
    ) -> T:
        """Async variant of `parse_result`.

        The synchronous `parse_result` is handed to `run_in_executor`, so the
        same "first `Generation` only" rule applies.

        Args:
            result: The `Generation` objects to parse. They are assumed to be
                alternative candidate outputs produced for one model input.
            partial: Whether `result` should be treated as a partial output.
                Forwarded unchanged to `parse_result`.

        Returns:
            Structured output.
        """
        sync_parse = self.parse_result
        parsed: T = await run_in_executor(None, sync_parse, result, partial=partial)
        return parsed

    async def aparse(self, text: str) -> T:
        """Async variant of `parse`.

        The synchronous `parse` is handed to `run_in_executor` and awaited.

        Args:
            text: String output of a language model.

        Returns:
            Structured output.
        """
        parsed: T = await run_in_executor(None, self.parse, text)
        return parsed

    # TODO: rename 'completion' -> 'text'.
    def parse_with_prompt(
        self,
        completion: str,
        prompt: PromptValue,  # noqa: ARG002
    ) -> Any:
        """Parse model output, with the originating prompt available as context.

        The prompt is offered for parsers that want to retry or repair the
        output and need information from the prompt to do so. This default
        implementation does not use it and simply calls `parse`.

        Args:
            completion: String output of a language model.
            prompt: Input `PromptValue`.

        Returns:
            Structured output.
        """
        parsed = self.parse(completion)
        return parsed

    def get_format_instructions(self) -> str:
        """Instructions on how the LLM output should be formatted.

        Raises:
            NotImplementedError: Always, unless a subclass overrides this method.
        """
        raise NotImplementedError

    @property
    def _type(self) -> str:
        """Return the output parser type for serialization.

        Raises:
            NotImplementedError: Always, unless a subclass overrides this property.
        """
        msg = (
            f"_type property is not implemented in class {self.__class__.__name__}."
            " This is required for serialization."
        )
        raise NotImplementedError(msg)

    def dict(self, **kwargs: Any) -> dict:
        """Return dictionary representation of output parser.

        The result of `model_dump` is extended with a `"_type"` entry when the
        `_type` property is implemented; otherwise the entry is left out.
        """
        serialized = super().model_dump(**kwargs)
        # `_type` raises NotImplementedError on parsers that do not define it;
        # that case is tolerated and the key is simply omitted.
        with contextlib.suppress(NotImplementedError):
            type_tag = self._type
            serialized["_type"] = type_tag
        return serialized
