from __future__ import annotations

from types import TracebackType
from typing import TYPE_CHECKING, Any, Type, Callable, cast
from typing_extensions import Self, Iterator, Awaitable, AsyncIterator, assert_never

import httpx
from pydantic import BaseModel

from anthropic.types.tool_use_block import ToolUseBlock
from anthropic.types.server_tool_use_block import ServerToolUseBlock

from ._types import (
    TextEvent,
    CitationEvent,
    ThinkingEvent,
    InputJsonEvent,
    SignatureEvent,
    MessageStopEvent,
    MessageStreamEvent,
    ContentBlockStopEvent,
)
from ...types import Message, ContentBlock, RawMessageStreamEvent
from ..._utils import consume_sync_iterator, consume_async_iterator
from ..._models import build, construct_type, construct_type_unchecked
from ..._streaming import Stream, AsyncStream


class MessageStream:
    text_stream: Iterator[str]
    """Iterator over just the text deltas in the stream.

    ```py
    for text in stream.text_stream:
        print(text, end="", flush=True)
    print()
    ```
    """

    def __init__(self, raw_stream: Stream[RawMessageStreamEvent]) -> None:
        self._raw_stream = raw_stream
        self.text_stream = self.__stream_text__()
        self._iterator = self.__stream__()
        self.__final_message_snapshot: Message | None = None

    @property
    def response(self) -> httpx.Response:
        return self._raw_stream.response

    @property
    def request_id(self) -> str | None:
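        """The value of the `request-id` response header, or `None` if it is not present."""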
        return self.response.headers.get("request-id")  # type: ignore[no-any-return]

    def __next__(self) -> MessageStreamEvent:
        return self._iterator.__next__()

    def __iter__(self) -> Iterator[MessageStreamEvent]:
        for item in self._iterator:
            yield item

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.close()

    def close(self) -> None:
        """
        Close the response and release the connection.

        Automatically called if the response body is read to completion.
        """
        self._raw_stream.close()

    def get_final_message(self) -> Message:
        """Waits until the stream has been read to completion and returns
        the accumulated `Message` object.
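
        For example, a minimal sketch assuming `client` has already been constructed:

        ```py
        with client.messages.stream(...) as stream:
            message = stream.get_final_message()
        ```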
        """
        self.until_done()
        assert self.__final_message_snapshot is not None
        return self.__final_message_snapshot

    def get_final_text(self) -> str:
        """Returns all `text` content blocks concatenated together.

        > [!NOTE]
        > Currently the API will only respond with a single content block.

        Will raise an error if no `text` content blocks were returned.
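
        For example, a minimal sketch assuming `client` has already been constructed:

        ```py
        with client.messages.stream(...) as stream:
            text = stream.get_final_text()
        ```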
        """
        message = self.get_final_message()
        text_blocks: list[str] = []
        for block in message.content:
            if block.type == "text":
                text_blocks.append(block.text)

        if not text_blocks:
            raise RuntimeError(
                ".get_final_text() can only be called when the API returns a `text` content block.\n"
                f"The API returned {', '.join(b.type for b in message.content)} content block type(s), "
                "which you can access by calling `.get_final_message().content`."
            )

        return "".join(text_blocks)

    def until_done(self) -> None:
        """Blocks until the stream has been consumed"""
        consume_sync_iterator(self)

    # properties
    @property
    def current_message_snapshot(self) -> Message:
        assert self.__final_message_snapshot is not None
        return self.__final_message_snapshot

    def __stream__(self) -> Iterator[MessageStreamEvent]:
        for sse_event in self._raw_stream:
            self.__final_message_snapshot = accumulate_event(
                event=sse_event,
                current_snapshot=self.__final_message_snapshot,
            )

            events_to_fire = build_events(event=sse_event, message_snapshot=self.current_message_snapshot)
            for event in events_to_fire:
                yield event

    def __stream_text__(self) -> Iterator[str]:
        for chunk in self:
            if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta":
                yield chunk.delta.text


class MessageStreamManager:
    """Wrapper over MessageStream that is returned by `.stream()`.

    ```py
    with client.messages.stream(...) as stream:
        for chunk in stream:
            ...
    ```
    """

    def __init__(
        self,
        api_request: Callable[[], Stream[RawMessageStreamEvent]],
    ) -> None:
        self.__stream: MessageStream | None = None
        self.__api_request = api_request

    def __enter__(self) -> MessageStream:
        raw_stream = self.__api_request()
        self.__stream = MessageStream(raw_stream)
        return self.__stream

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        if self.__stream is not None:
            self.__stream.close()


class AsyncMessageStream:
    text_stream: AsyncIterator[str]
    """Async iterator over just the text deltas in the stream.

    ```py
    async for text in stream.text_stream:
        print(text, end="", flush=True)
    print()
    ```
    """

    def __init__(self, raw_stream: AsyncStream[RawMessageStreamEvent]) -> None:
        self._raw_stream = raw_stream
        self.text_stream = self.__stream_text__()
        self._iterator = self.__stream__()
        self.__final_message_snapshot: Message | None = None

    @property
    def response(self) -> httpx.Response:
        return self._raw_stream.response

    @property
    def request_id(self) -> str | None:
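        """The value of the `request-id` response header, or `None` if it is not present."""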
        return self.response.headers.get("request-id")  # type: ignore[no-any-return]

    async def __anext__(self) -> MessageStreamEvent:
        return await self._iterator.__anext__()

    async def __aiter__(self) -> AsyncIterator[MessageStreamEvent]:
        async for item in self._iterator:
            yield item

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        await self.close()

    async def close(self) -> None:
        """
        Close the response and release the connection.

        Automatically called if the response body is read to completion.
        """
        await self._raw_stream.close()

    async def get_final_message(self) -> Message:
        """Waits until the stream has been read to completion and returns
        the accumulated `Message` object.
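
        For example, a minimal sketch assuming `client` is an already-constructed async client:

        ```py
        async with client.messages.stream(...) as stream:
            message = await stream.get_final_message()
        ```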
        """
        await self.until_done()
        assert self.__final_message_snapshot is not None
        return self.__final_message_snapshot

    async def get_final_text(self) -> str:
        """Returns all `text` content blocks concatenated together.

        > [!NOTE]
        > Currently the API will only respond with a single content block.

        Will raise an error if no `text` content blocks were returned.
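
        For example, a minimal sketch assuming `client` is an already-constructed async client:

        ```py
        async with client.messages.stream(...) as stream:
            text = await stream.get_final_text()
        ```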
        """
        message = await self.get_final_message()
        text_blocks: list[str] = []
        for block in message.content:
            if block.type == "text":
                text_blocks.append(block.text)

        if not text_blocks:
            raise RuntimeError(
                ".get_final_text() can only be called when the API returns a `text` content block.\n"
                f"The API returned {', '.join(b.type for b in message.content)} content block type(s), "
                "which you can access by calling `.get_final_message().content`."
            )

        return "".join(text_blocks)

    async def until_done(self) -> None:
        """Waits until the stream has been consumed"""
        await consume_async_iterator(self)

    # properties
    @property
    def current_message_snapshot(self) -> Message:
        assert self.__final_message_snapshot is not None
        return self.__final_message_snapshot

    async def __stream__(self) -> AsyncIterator[MessageStreamEvent]:
        async for sse_event in self._raw_stream:
            self.__final_message_snapshot = accumulate_event(
                event=sse_event,
                current_snapshot=self.__final_message_snapshot,
            )

            events_to_fire = build_events(event=sse_event, message_snapshot=self.current_message_snapshot)
            for event in events_to_fire:
                yield event

    async def __stream_text__(self) -> AsyncIterator[str]:
        async for chunk in self:
            if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta":
                yield chunk.delta.text


class AsyncMessageStreamManager:
    """Wrapper over AsyncMessageStream that is returned by `.stream()`
    so that an async context manager can be used without `await`ing the
    original client call.

    ```py
    async with client.messages.stream(...) as stream:
        async for chunk in stream:
            ...
    ```
    """

    def __init__(
        self,
        api_request: Awaitable[AsyncStream[RawMessageStreamEvent]],
    ) -> None:
        self.__stream: AsyncMessageStream | None = None
        self.__api_request = api_request

    async def __aenter__(self) -> AsyncMessageStream:
        raw_stream = await self.__api_request
        self.__stream = AsyncMessageStream(raw_stream)
        return self.__stream

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        if self.__stream is not None:
            await self.__stream.close()


def build_events(
    *,
    event: RawMessageStreamEvent,
    message_snapshot: Message,
) -> list[MessageStreamEvent]:
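    """Map a single raw stream event to the list of events to emit to the consumer.

    Raw events are passed through unchanged, while `message_stop` and `content_block_stop`
    events are rebuilt to include the accumulated snapshot; `content_block_delta` events
    additionally produce higher-level helper events such as `TextEvent` and `InputJsonEvent`.
    """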
    events_to_fire: list[MessageStreamEvent] = []

    if event.type == "message_start":
        events_to_fire.append(event)
    elif event.type == "message_delta":
        events_to_fire.append(event)
    elif event.type == "message_stop":
        events_to_fire.append(build(MessageStopEvent, type="message_stop", message=message_snapshot))
    elif event.type == "content_block_start":
        events_to_fire.append(event)
    elif event.type == "content_block_delta":
        events_to_fire.append(event)

        content_block = message_snapshot.content[event.index]
        if event.delta.type == "text_delta":
            if content_block.type == "text":
                events_to_fire.append(
                    build(
                        TextEvent,
                        type="text",
                        text=event.delta.text,
                        snapshot=content_block.text,
                    )
                )
        elif event.delta.type == "input_json_delta":
            if content_block.type == "tool_use":
                events_to_fire.append(
                    build(
                        InputJsonEvent,
                        type="input_json",
                        partial_json=event.delta.partial_json,
                        snapshot=content_block.input,
                    )
                )
        elif event.delta.type == "citations_delta":
            if content_block.type == "text":
                events_to_fire.append(
                    build(
                        CitationEvent,
                        type="citation",
                        citation=event.delta.citation,
                        snapshot=content_block.citations or [],
                    )
                )
        elif event.delta.type == "thinking_delta":
            if content_block.type == "thinking":
                events_to_fire.append(
                    build(
                        ThinkingEvent,
                        type="thinking",
                        thinking=event.delta.thinking,
                        snapshot=content_block.thinking,
                    )
                )
        elif event.delta.type == "signature_delta":
            if content_block.type == "thinking":
                events_to_fire.append(
                    build(
                        SignatureEvent,
                        type="signature",
                        signature=content_block.signature,
                    )
                )
        else:
            # we only want exhaustive checking for linters, not at runtime
            if TYPE_CHECKING:  # type: ignore[unreachable]
                assert_never(event.delta)
    elif event.type == "content_block_stop":
        content_block = message_snapshot.content[event.index]

        events_to_fire.append(
            build(ContentBlockStopEvent, type="content_block_stop", index=event.index, content_block=content_block),
        )
    else:
        # we only want exhaustive checking for linters, not at runtime
        if TYPE_CHECKING:  # type: ignore[unreachable]
            assert_never(event)

    return events_to_fire


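# Name of the untyped attribute used to buffer the raw partial-JSON string on a
# tool-use content block while its `input` is being accumulated.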
JSON_BUF_PROPERTY = "__json_buf"

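# Content block types whose `input` field is accumulated from `input_json_delta` events.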
TRACKS_TOOL_INPUT = (
    ToolUseBlock,
    ServerToolUseBlock,
)


def accumulate_event(
    *,
    event: RawMessageStreamEvent,
    current_snapshot: Message | None,
) -> Message:
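    """Merge a single raw stream event into the running `Message` snapshot and return it.

    A `message_start` event creates the initial snapshot; subsequent events append content
    blocks, apply text/JSON/citation/thinking/signature deltas to them, and update the
    stop reason, stop sequence and usage fields.
    """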
    if not isinstance(cast(Any, event), BaseModel):
        event = cast(  # pyright: ignore[reportUnnecessaryCast]
            RawMessageStreamEvent,
            construct_type_unchecked(
                type_=cast(Type[RawMessageStreamEvent], RawMessageStreamEvent),
                value=event,
            ),
        )
        if not isinstance(cast(Any, event), BaseModel):
            raise TypeError(f"Unexpected event runtime type, after deserialising twice - {event} - {type(event)}")

    if current_snapshot is None:
        if event.type == "message_start":
            return Message.construct(**cast(Any, event.message.to_dict()))

        raise RuntimeError(f'Unexpected event order, got {event.type} before "message_start"')

    if event.type == "content_block_start":
        # TODO: check index
        current_snapshot.content.append(
            cast(
                ContentBlock,
                construct_type(type_=ContentBlock, value=event.content_block.model_dump()),
            ),
        )
    elif event.type == "content_block_delta":
        content = current_snapshot.content[event.index]
        if event.delta.type == "text_delta":
            if content.type == "text":
                content.text += event.delta.text
        elif event.delta.type == "input_json_delta":
            if isinstance(content, TRACKS_TOOL_INPUT):
                from jiter import from_json

                # we need to keep track of the raw JSON string as well so that we can
                # re-parse it for each delta, for now we just store it as an untyped
                # property on the snapshot
                json_buf = cast(bytes, getattr(content, JSON_BUF_PROPERTY, b""))
                json_buf += bytes(event.delta.partial_json, "utf-8")

                if json_buf:
                    content.input = from_json(json_buf, partial_mode=True)

                setattr(content, JSON_BUF_PROPERTY, json_buf)
        elif event.delta.type == "citations_delta":
            if content.type == "text":
                if not content.citations:
                    content.citations = [event.delta.citation]
                else:
                    content.citations.append(event.delta.citation)
        elif event.delta.type == "thinking_delta":
            if content.type == "thinking":
                content.thinking += event.delta.thinking
        elif event.delta.type == "signature_delta":
            if content.type == "thinking":
                content.signature = event.delta.signature
        else:
            # we only want exhaustive checking for linters, not at runtime
            if TYPE_CHECKING:  # type: ignore[unreachable]
                assert_never(event.delta)
    elif event.type == "message_delta":
        current_snapshot.stop_reason = event.delta.stop_reason
        current_snapshot.stop_sequence = event.delta.stop_sequence
        current_snapshot.usage.output_tokens = event.usage.output_tokens

        # Update other usage fields if they exist in the event
        if event.usage.input_tokens is not None:
            current_snapshot.usage.input_tokens = event.usage.input_tokens
        if event.usage.cache_creation_input_tokens is not None:
            current_snapshot.usage.cache_creation_input_tokens = event.usage.cache_creation_input_tokens
        if event.usage.cache_read_input_tokens is not None:
            current_snapshot.usage.cache_read_input_tokens = event.usage.cache_read_input_tokens
        if event.usage.server_tool_use is not None:
            current_snapshot.usage.server_tool_use = event.usage.server_tool_use

    return current_snapshot
