"""BasePrompt schema definition."""

from __future__ import annotations

import warnings
from abc import ABC
from string import Formatter
from typing import TYPE_CHECKING, Any, Literal

from pydantic import BaseModel, create_model

from langchain_core.prompt_values import PromptValue, StringPromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.utils import get_colored_text, mustache
from langchain_core.utils.formatting import formatter
from langchain_core.utils.interactive_env import is_interactive_env

if TYPE_CHECKING:
    from collections.abc import Callable, Sequence

# jinja2 is an optional dependency. The restricted sandbox class below is only
# defined when the import succeeds; _HAS_JINJA2 records the outcome so that the
# jinja2 helpers in this module can raise a helpful ImportError instead of a
# NameError when jinja2 is missing.
try:
    from jinja2 import meta
    from jinja2.exceptions import SecurityError
    from jinja2.sandbox import SandboxedEnvironment

    class _RestrictedSandboxedEnvironment(SandboxedEnvironment):
        """A more restrictive Jinja2 sandbox that blocks all attribute/method access.

        This sandbox only allows simple variable lookups, no attribute or method access.
        This prevents template injection attacks via methods like parse_raw().
        """

        def is_safe_attribute(self, _obj: Any, _attr: str, _value: Any) -> bool:
            """Block ALL attribute access for security.

            Only allow accessing variables directly from the context dict,
            no attribute access on those objects.

            Args:
                _obj: The object being accessed (unused, always blocked).
                _attr: The attribute name (unused, always blocked).
                _value: The attribute value (unused, always blocked).

            Returns:
                False - all attribute access is blocked.
            """
            # Block all attribute access, regardless of the object or attribute.
            return False

        def is_safe_callable(self, _obj: Any) -> bool:
            """Block all method calls for security.

            Args:
                _obj: The object being checked (unused, always blocked).

            Returns:
                False - all callables are blocked.
            """
            return False

        def getattr(self, obj: Any, attribute: str) -> Any:
            """Override getattr to block all attribute access.

            Args:
                obj: The object whose attribute was requested. Only its type
                    name is used, to build the error message.
                attribute: The attribute name that the template tried to read.

            Returns:
                Nothing; this method always raises.

            Raises:
                SecurityError: Always, to block attribute access.
            """
            # The quadruple braces render as a literal "{{variable}}" in the
            # message, showing the only template syntax that is permitted.
            msg = (
                f"Access to attributes is not allowed in templates. "
                f"Attempted to access '{attribute}' on {type(obj).__name__}. "
                f"Use only simple variable names like {{{{variable}}}} "
                f"without dots or methods."
            )
            raise SecurityError(msg)

    _HAS_JINJA2 = True
except ImportError:
    # jinja2 is not installed; jinja2-based helpers check this flag and raise.
    _HAS_JINJA2 = False

# Names of the template syntaxes handled by this module. These are the same
# strings used as keys of DEFAULT_FORMATTER_MAPPING.
PromptTemplateFormat = Literal["f-string", "mustache", "jinja2"]


def jinja2_formatter(template: str, /, **kwargs: Any) -> str:
    """Format a template using jinja2.

    The template is rendered inside `_RestrictedSandboxedEnvironment`, which
    rejects attribute access and method calls on the supplied values, so only
    simple variable references such as `{{variable}}` can be rendered.

    *Security warning*:
        As of LangChain 0.0.329, this method uses a sandboxed Jinja2
        environment by default. However, this sand-boxing should
        be treated as a best-effort approach rather than a guarantee of security.
        Do not accept jinja2 templates from untrusted sources as they may lead
        to arbitrary Python code execution.

        https://jinja.palletsprojects.com/en/3.1.x/sandbox/

    Args:
        template: The template string.
        **kwargs: The variables to format the template with.

    Returns:
        The formatted string.

    Raises:
        ImportError: If jinja2 is not installed.
    """
    if not _HAS_JINJA2:
        # Each fragment ends with a space so that the concatenated sentences
        # do not run together in the final message.
        msg = (
            "jinja2 not installed, which is needed to use the jinja2_formatter. "
            "Please install it with `pip install jinja2`. "
            "Please be cautious when using jinja2 templates. "
            "Do not expand jinja2 templates using unverified or user-controlled "
            "inputs as that can result in arbitrary Python code execution."
        )
        raise ImportError(msg)

    # Use a restricted sandbox that blocks ALL attribute/method access.
    # Only simple variable lookups like {{variable}} are allowed;
    # attribute access like {{variable.attr}} or {{variable.method()}} is blocked.
    return _RestrictedSandboxedEnvironment().from_string(template).render(**kwargs)


def validate_jinja2(template: str, input_variables: list[str]) -> None:
    """Warn when the declared variables do not match those used by the template.

    The variables referenced by the jinja2 template are compared with
    `input_variables`. Nothing is raised on a mismatch; a single warning
    describing the missing and/or extra variables is emitted instead.

    Args:
        template: The template string.
        input_variables: The input variables.
    """
    declared = set(input_variables)
    referenced = _get_jinja2_variables_from_template(template)

    # Used by the template but not declared by the caller.
    missing = referenced - declared
    # Declared by the caller but never used by the template.
    extra = declared - referenced

    missing_part = f"Missing variables: {missing} " if missing else ""
    extra_part = f"Extra variables: {extra}" if extra else ""

    # Stripping removes the trailing separator when only the first part is present.
    message = (missing_part + extra_part).strip()
    if message:
        warnings.warn(message, stacklevel=7)


def _get_jinja2_variables_from_template(template: str) -> set[str]:
    """Return the undeclared variable names referenced by a jinja2 template.

    Args:
        template: The template string.

    Returns:
        The names reported by `meta.find_undeclared_variables` for the
        parsed template.

    Raises:
        ImportError: If jinja2 is not installed.
    """
    if _HAS_JINJA2:
        # Parsing only builds the template AST; nothing is rendered here.
        parsed = _RestrictedSandboxedEnvironment().parse(template)
        return meta.find_undeclared_variables(parsed)

    msg = (
        "jinja2 not installed, which is needed to use the jinja2_formatter. "
        "Please install it with `pip install jinja2`."
    )
    raise ImportError(msg)


def mustache_formatter(template: str, /, **kwargs: Any) -> str:
    """Format a template using mustache.

    The keyword arguments are collected into a single dict and handed to
    `mustache.render` together with the template.

    Args:
        template: The template string (positional-only).
        **kwargs: The variables to format the template with.

    Returns:
        The formatted string.
    """
    return mustache.render(template, kwargs)


def mustache_template_vars(
    template: str,
) -> set[str]:
    """Get the top-level variables from a mustache template.

    Only tags that appear outside of any section contribute a name, and for
    dotted keys like `{{person.name}}` only the first component (`person`)
    is returned. The implicit-iterator key `.` is ignored.

    Args:
        template: The template string.

    Returns:
       The top-level variables from the template.
    """
    section_kinds = {"section", "inverted section"}
    names: set[str] = set()
    # Number of sections currently open around the token being inspected.
    depth = 0
    for token_kind, token_key in mustache.tokenize(template):
        if token_kind == "end":
            depth -= 1
            continue
        opens_section = token_kind in section_kinds
        is_reference = opens_section or token_kind in {"variable", "no escape"}
        if is_reference and depth == 0 and token_key != ".":
            names.add(token_key.split(".")[0])
        # A section tag is itself recorded (above) before it deepens the nesting.
        if opens_section:
            depth += 1
    return names


# Recursive mapping of field name -> nested definitions, built by
# mustache_schema and consumed by _create_model_recursive. mustache_schema
# stores an empty string (rather than a nested dict) at leaf positions.
Defs = dict[str, "Defs"]


def mustache_schema(template: str) -> type[BaseModel]:
    """Build a Pydantic model describing the variables of a mustache template.

    The template is tokenized and every referenced key is recorded as a path
    (a tuple of name components), relative to the sections that enclose it.
    The paths are then folded into a nested mapping and converted into a
    model named `PromptInput`.

    Args:
        template: The template string.

    Returns:
        The variables from the template as a Pydantic model.
    """
    # Maps each key path to True when it was last seen as a plain variable
    # and False when it was opened as a section.
    path_is_leaf: dict[tuple[str, ...], bool] = {}
    scope: tuple[str, ...] = ()
    # Scopes to restore when the matching section end tag is reached.
    saved_scopes: list[tuple[str, ...]] = []
    for token_kind, token_key in mustache.tokenize(template):
        if token_key == ".":
            continue
        if token_kind == "end":
            if saved_scopes:
                scope = saved_scopes.pop()
            continue
        if token_kind in {"section", "inverted section"}:
            saved_scopes.append(scope)
            scope = scope + tuple(token_key.split("."))
            path_is_leaf[scope] = False
        elif token_kind in {"variable", "no escape"}:
            path_is_leaf[scope + tuple(token_key.split("."))] = True

    # A variable path that is a prefix of another recorded path cannot be a
    # leaf, because it must hold the nested fields. Only existing keys are
    # reassigned, so the dict keeps its size and order during iteration.
    for path in path_is_leaf:
        if path_is_leaf[path] and any(
            other != path and is_subsequence(path, other) for other in path_is_leaf
        ):
            path_is_leaf[path] = False

    # Leaves are stored as an empty string, non-leaves as a nested dict.
    defs: Defs = {}
    # Walk the paths from the most recently inserted to the oldest.
    for path, is_leaf in reversed(list(path_is_leaf.items())):
        node = defs
        for part in path[:-1]:
            node = node.setdefault(part, {})
        node.setdefault(path[-1], "" if is_leaf else {})  # type: ignore[arg-type]
    return _create_model_recursive("PromptInput", defs)


def _create_model_recursive(name: str, defs: Defs) -> type:
    """Create a Pydantic model from a nested definitions mapping.

    Every entry of `defs` becomes a field with a default of `None`. A
    non-empty value is turned into a nested model named after the field; an
    empty value uses its own Python type as the field type.

    Args:
        name: The name of the model to create.
        defs: Mapping of field name to nested definitions.

    Returns:
        The model class returned by `create_model`.
    """
    field_definitions: dict[str, Any] = {}
    for field_name, sub_defs in defs.items():
        if sub_defs:
            field_type = _create_model_recursive(field_name, sub_defs)
        else:
            field_type = type(sub_defs)
        field_definitions[field_name] = (field_type, None)
    return create_model(name, **field_definitions)  # type: ignore[call-overload]


# Template format name -> callable that renders a template of that format.
DEFAULT_FORMATTER_MAPPING: dict[str, Callable] = {
    "f-string": formatter.format,
    "mustache": mustache_formatter,
    "jinja2": jinja2_formatter,
}

# Template format name -> callable that validates a template against its input
# variables. There is no "mustache" entry, so check_valid_template raises
# ValueError for that format.
DEFAULT_VALIDATOR_MAPPING: dict[str, Callable] = {
    "f-string": formatter.validate_input_variables,
    "jinja2": validate_jinja2,
}


def check_valid_template(
    template: str, template_format: str, input_variables: list[str]
) -> None:
    """Check that template string is valid.

    The validator registered for `template_format` in
    `DEFAULT_VALIDATOR_MAPPING` is looked up and called with the template
    and the input variables.

    Args:
        template: The template string.
        template_format: The template format. Should be one of "f-string" or "jinja2".
        input_variables: The input variables.

    Raises:
        ValueError: If the template format is not supported.
        ValueError: If the prompt schema is invalid.
    """
    try:
        validator_func = DEFAULT_VALIDATOR_MAPPING[template_format]
    except KeyError as exc:
        # Report the formats that actually have a validator. The formatter
        # mapping also contains formats (e.g. "mustache") that cannot be
        # validated here, which would make the message self-contradictory.
        msg = (
            f"Invalid template format {template_format!r}, should be one of"
            f" {list(DEFAULT_VALIDATOR_MAPPING)}."
        )
        raise ValueError(msg) from exc
    try:
        validator_func(template, input_variables)
    except (KeyError, IndexError) as exc:
        # Lookup failures raised by the validator are reported as an invalid
        # schema, with the original exception chained as the cause.
        msg = (
            "Invalid prompt schema; check for mismatched or missing input parameters"
            f" from {input_variables}."
        )
        raise ValueError(msg) from exc


def get_template_variables(template: str, template_format: str) -> list[str]:
    """Get the variables from the template.

    For f-string templates the field names are additionally checked: names
    using attribute access or indexing, and names made only of digits, are
    rejected.

    Args:
        template: The template string.
        template_format: The template format. One of "f-string", "jinja2"
            or "mustache".

    Returns:
        The variables from the template, sorted.

    Raises:
        ValueError: If the template format is not supported.
        ValueError: If an f-string template uses an invalid variable name.
    """
    if template_format == "jinja2":
        return sorted(_get_jinja2_variables_from_template(template))

    if template_format == "mustache":
        return sorted(mustache_template_vars(template))

    if template_format != "f-string":
        msg = f"Unsupported template format: {template_format}"
        raise ValueError(msg)

    # Formatter().parse() yields (literal, field_name, spec, conversion);
    # field_name is None for chunks that contain no replacement field.
    names = set()
    for _literal, field_name, _spec, _conversion in Formatter().parse(template):
        if field_name is not None:
            names.add(field_name)

    for name in names:
        # Field names keep any dots/brackets, e.g. "obj.attr" or "obj[0]".
        # Reject them to prevent template injection through attribute or
        # item access on the supplied values.
        if any(marker in name for marker in ".[]"):
            msg = (
                f"Invalid variable name {name!r} in f-string template. "
                f"Variable names cannot contain attribute "
                f"access (.) or indexing ([])."
            )
            raise ValueError(msg)

        # All-digit names (e.g. "0", "100") refer to positional arguments,
        # not keyword arguments, so they are rejected as well.
        if name.isdigit():
            msg = (
                f"Invalid variable name {name!r} in f-string template. "
                f"Variable names cannot be all digits as they are interpreted "
                f"as positional arguments."
            )
            raise ValueError(msg)

    return sorted(names)


class StringPromptTemplate(BasePromptTemplate, ABC):
    """Base class for prompt templates whose formatted output is a string."""

    @classmethod
    def get_lc_namespace(cls) -> list[str]:
        """Return the LangChain namespace of this object.

        Returns:
            `["langchain", "prompts", "base"]`
        """
        namespace = ["langchain", "prompts", "base"]
        return namespace

    def format_prompt(self, **kwargs: Any) -> PromptValue:
        """Format the prompt and wrap the result in a prompt value.

        Args:
            **kwargs: Any arguments to be passed to the prompt template.

        Returns:
            A `StringPromptValue` holding the formatted string.
        """
        text = self.format(**kwargs)
        return StringPromptValue(text=text)

    async def aformat_prompt(self, **kwargs: Any) -> PromptValue:
        """Asynchronously format the prompt and wrap it in a prompt value.

        Args:
            **kwargs: Any arguments to be passed to the prompt template.

        Returns:
            A `StringPromptValue` holding the formatted string.
        """
        text = await self.aformat(**kwargs)
        return StringPromptValue(text=text)

    def pretty_repr(
        self,
        html: bool = False,  # noqa: FBT001,FBT002
    ) -> str:
        """Render the prompt with each input variable shown as `{name}`.

        Args:
            html: Whether to return an HTML-formatted string. When true, each
                placeholder is passed through `get_colored_text` with the
                color "yellow".

        Returns:
            A pretty representation of the prompt.
        """
        # TODO: handle partials
        placeholders = {}
        for name in self.input_variables:
            shown = "{" + f"{name}" + "}"
            if html:
                shown = get_colored_text(shown, "yellow")
            placeholders[name] = shown
        return self.format(**placeholders)

    def pretty_print(self) -> None:
        """Print a pretty representation of the prompt."""
        use_html = is_interactive_env()
        print(self.pretty_repr(html=use_html))  # noqa: T201


def is_subsequence(child: Sequence, parent: Sequence) -> bool:
    """Return True if child is a non-empty leading run (prefix) of parent.

    Despite the name, elements are compared position by position from the
    start, so `child` must match the beginning of `parent`. An empty `child`
    or `parent`, or a `child` longer than `parent`, yields False.
    """
    child_len = len(child)
    # A non-empty child longer than the parent also covers the empty-parent case.
    if child_len == 0 or len(parent) < child_len:
        return False
    # zip stops after the last child element because child is the shorter side.
    return all(
        child_item == parent_item for child_item, parent_item in zip(child, parent)
    )
