from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Optional, TypedDict, Union

from langsmith.evaluation.evaluator import DynamicRunEvaluator
from langsmith.run_helpers import traceable
from langsmith.schemas import Example, Run

if TYPE_CHECKING:
    from langchain.evaluation.schema import (  # type: ignore[import-not-found]
        StringEvaluator,
    )

    from langsmith.evaluation.evaluator import (  # type: ignore[import-not-found]
        RunEvaluator,
    )


class SingleEvaluatorInput(TypedDict):
    """The input to a `StringEvaluator`."""

    prediction: str
    """The prediction string."""
    reference: Optional[Any]
    """The optional reference value (typically a string)."""
    input: Optional[str]
    """The optional input string."""


class LangChainStringEvaluator:
    r"""A class for wrapping a LangChain `StringEvaluator`.

    Requires the `langchain` package to be installed.

    Attributes:
        evaluator (StringEvaluator): The underlying `StringEvaluator` OR the name
            of the evaluator to load.

    Methods:
        `as_run_evaluator() -> RunEvaluator`:
            Convert the `LangChainStringEvaluator` to a `RunEvaluator`.

    Examples:
        !!! example "Creating a simple `LangChainStringEvaluator`"

            ```python
            evaluator = LangChainStringEvaluator("exact_match")
            ```
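
        !!! example "Wrapping an existing `StringEvaluator` instance"

            A sketch of passing a pre-constructed evaluator rather than a
            string name. Here the instance comes from LangChain's
            `load_evaluator`, but any `StringEvaluator` instance works.

            ```python
            from langchain.evaluation import load_evaluator

            string_evaluator = load_evaluator("embedding_distance")
            evaluator = LangChainStringEvaluator(string_evaluator)
            ```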

        !!! example "Converting a `LangChainStringEvaluator` to a `RunEvaluator`"

            ```python
            from langsmith.evaluation import LangChainStringEvaluator
            from langchain_openai import ChatOpenAI
            evaluator = LangChainStringEvaluator(
                "criteria",
                config={
                    "criteria": {
                        "usefulness": "The prediction is useful if"
                        " it is correct and/or asks a useful followup question."
                    },
                    "llm": ChatOpenAI(model="gpt-4o"),
                },
            )
            run_evaluator = evaluator.as_run_evaluator()
            run_evaluator  # -> <DynamicRunEvaluator ...>
            ```

        !!! example "Customizing the LLM model used by the evaluator"

            ```python
            from langsmith.evaluation import LangChainStringEvaluator
            from langchain_anthropic import ChatAnthropic
            evaluator = LangChainStringEvaluator(
                "criteria",
                config={
                    "criteria": {
                        "usefulness": "The prediction is useful if"
                        " it is correct and/or asks a useful followup question."
                    },
                    "llm": ChatAnthropic(model="claude-3-opus-20240229"),
                },
            )

            run_evaluator = evaluator.as_run_evaluator()
            run_evaluator  # -> <DynamicRunEvaluator ...>
            ```

        !!! example "Using the `evaluate` API with different evaluators"

            ```python
            import re

            from langchain_anthropic import ChatAnthropic

            from langsmith.evaluation import LangChainStringEvaluator, evaluate
            from langsmith.schemas import Example, Run


            def prepare_data(run: Run, example: Example):
                # Convert the evaluation data into the format expected by the evaluator
                # Only required for datasets with multiple input/output keys
                return {
                    "prediction": run.outputs["prediction"],
                    "reference": example.outputs["answer"],
                    "input": str(example.inputs),
                }


            criteria_evaluator = LangChainStringEvaluator(
                "criteria",
                config={
                    "criteria": {
                        "usefulness": "The prediction is useful if it is correct"
                        " and/or asks a useful followup question."
                    },
                    "llm": ChatAnthropic(model="claude-3-opus-20240229"),
                },
                prepare_data=prepare_data,
            )

            embedding_evaluator = LangChainStringEvaluator("embedding_distance")
            exact_match_evaluator = LangChainStringEvaluator("exact_match")
            regex_match_evaluator = LangChainStringEvaluator(
                "regex_match", config={"flags": re.IGNORECASE}, prepare_data=prepare_data
            )

            scoring_evaluator = LangChainStringEvaluator(
                "labeled_score_string",
                config={
                    "criteria": {
                        "accuracy": "Score 1: Completely inaccurate\nScore 5: Somewhat accurate\nScore 10: Completely accurate"
                    },
                    "normalize_by": 10,
                    "llm": ChatAnthropic(model="claude-3-opus-20240229"),
                },
                prepare_data=prepare_data,
            )
            string_distance_evaluator = LangChainStringEvaluator(
                "string_distance",
                config={"distance_metric": "levenshtein"},
                prepare_data=prepare_data,
            )

            from langsmith import Client

            client = Client()
            results = evaluate(
                lambda inputs: {"prediction": "foo"},
                data=client.list_examples(dataset_name="Evaluate Examples", limit=1),
                evaluators=[
                    embedding_evaluator,
                    criteria_evaluator,
                    exact_match_evaluator,
                    regex_match_evaluator,
                    scoring_evaluator,
                    string_distance_evaluator,
                ],
            )
            ```
    """  # noqa: E501

    def __init__(
        self,
        evaluator: Union[StringEvaluator, str],
        *,
        config: Optional[dict] = None,
        prepare_data: Optional[
            Callable[[Run, Optional[Example]], SingleEvaluatorInput]
        ] = None,
    ):
        """Initialize a `LangChainStringEvaluator`.

        Args:
            evaluator (Union[StringEvaluator, str]): The underlying
                `StringEvaluator`, or the name of an evaluator to load with
                `langchain.evaluation.load_evaluator`.
            config (Optional[dict]): Keyword arguments forwarded to
                `load_evaluator` when `evaluator` is given as a string.
                Ignored when an evaluator instance is passed.
            prepare_data (Optional[Callable]): A function mapping a run and
                an optional example to the `SingleEvaluatorInput` expected by
                the evaluator. Defaults to extracting the single output,
                reference, and input values.
        """
        from langchain.evaluation.schema import StringEvaluator  # noqa: F811

        if isinstance(evaluator, StringEvaluator):
            self.evaluator = evaluator
        elif isinstance(evaluator, str):
            from langchain.evaluation import (  # type: ignore[import-not-found]
                load_evaluator,  # noqa: F811
            )

            self.evaluator = load_evaluator(evaluator, **(config or {}))  # type: ignore[assignment, arg-type]
        else:
            raise NotImplementedError(f"Unsupported evaluator type: {type(evaluator)}")

        self._prepare_data = prepare_data

    def as_run_evaluator(
        self,
    ) -> RunEvaluator:
        """Convert the `LangChainStringEvaluator` to a `RunEvaluator`.

        This is the object used in the LangSmith `evaluate` API.

        Returns:
            RunEvaluator: The converted `RunEvaluator`.
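
        !!! example

            A minimal sketch; `run` and `example` are placeholders for
            objects from your experiment. Note that `evaluate` also accepts
            `LangChainStringEvaluator` instances directly and performs this
            conversion for you.

            ```python
            evaluator = LangChainStringEvaluator("exact_match")
            run_evaluator = evaluator.as_run_evaluator()
            result = run_evaluator.evaluate_run(run, example)
            ```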
        """
        input_str = (
            "\n       \"input\": example.inputs['input'],"
            if self.evaluator.requires_input
            else ""
        )
        reference_str = (
            "\n       \"reference\": example.outputs['expected']"
            if self.evaluator.requires_reference
            else ""
        )
        customization_error_str = f"""
def prepare_data(run, example):
    return {{
        "prediction": run.outputs['my_output'],{reference_str}{input_str}
    }}
evaluator = LangChainStringEvaluator(..., prepare_data=prepare_data)
"""

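        # Default data preparation: when no `prepare_data` is supplied, take
        # the run's single output as the prediction and, if the evaluator
        # requires them, the example's single output/input as the
        # reference/input.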
        @traceable
        def prepare_evaluator_inputs(
            run: Run, example: Optional[Example] = None
        ) -> SingleEvaluatorInput:
            if run.outputs and len(run.outputs) > 1:
                raise ValueError(
                    f"Evaluator {self.evaluator} only supports a single prediction "
                    "key. Please ensure that the run has a single output."
                    " Or initialize with a prepare_data:\n"
                    f"{customization_error_str}"
                )
            if (
                self.evaluator.requires_reference
                and example
                and example.outputs
                and len(example.outputs) > 1
            ):
                raise ValueError(
                    f"Evaluator {self.evaluator} nly supports a single reference key. "
                    "Please ensure that the example has a single output."
                    " Or create a custom evaluator yourself:\n"
                    f"{customization_error_str}"
                )
            if (
                self.evaluator.requires_input
                and example
                and example.inputs
                and len(example.inputs) > 1
            ):
                raise ValueError(
                    f"Evaluator {self.evaluator} only supports a single input key. "
                    "Please ensure that the example has a single input."
                    " Or initialize with a prepare_data:\n"
                    f"{customization_error_str}"
                )

            return SingleEvaluatorInput(
                prediction=next(iter(run.outputs.values())),  # type: ignore[union-attr]
                reference=(
                    next(iter(example.outputs.values()))
                    if (
                        self.evaluator.requires_reference
                        and example
                        and example.outputs
                    )
                    else None
                ),
                input=(
                    next(iter(example.inputs.values()))
                    if (self.evaluator.requires_input and example and example.inputs)
                    else None
                ),
            )

        @traceable(name=self.evaluator.evaluation_name)
        def evaluate(run: Run, example: Optional[Example] = None) -> dict:
            eval_inputs = (
                prepare_evaluator_inputs(run, example)
                if self._prepare_data is None
                else self._prepare_data(run, example)
            )
            results = self.evaluator.evaluate_strings(**eval_inputs)
            return {"key": self.evaluator.evaluation_name, **results}

        @traceable(name=self.evaluator.evaluation_name)
        async def aevaluate(run: Run, example: Optional[Example] = None) -> dict:
            eval_inputs = (
                prepare_evaluator_inputs(run, example)
                if self._prepare_data is None
                else self._prepare_data(run, example)
            )
            results = await self.evaluator.aevaluate_strings(**eval_inputs)
            return {"key": self.evaluator.evaluation_name, **results}

        return DynamicRunEvaluator(evaluate, aevaluate)
