import abc
import asyncio
import gc
import tempfile
import typing
from pathlib import Path

import fastapi
import inspect

from source import utils
from source.registry import ModelRegistry
from source.utils.fastapi import UploadFileFix


class BaseModel(abc.ABC):
    """
    Represent a model served through the API.

    Subclasses implement `_load`, `_unload` and `_infer`. This base class handles the bookkeeping around them:
    - `load` / `unload` are guarded by a per-model lock and keep the registry's `current_loaded_model` up to date
      (loading a model first unloads the model currently referenced by the registry);
    - `infer` serializes inferences through the registry's `infer_lock` and loads the model on demand;
    - `mount` exposes the inference as a streaming POST endpoint on a FastAPI application.
    """

    def __init__(self, registry: ModelRegistry, configuration: dict[str, typing.Any], path: Path):
        """
        :param registry: the model registry owning this model
        :param configuration: the model configuration. The keys read here are "summary", "description",
            "inputs" (default empty dict), "output_type" (default "application/json"), "tags" (default empty list)
            and "interface" (optional name of an interface type looked up in the registry's interface registry)
        :param path: the environment directory of the model; its name is used as the model name
        """

        # the model registry
        self.registry = registry

        # get the documentation of the model
        self.summary = configuration.get("summary")
        self.description = configuration.get("description")

        # the environment directory of the model
        self.path = path
        # get the parameters of the model
        self.inputs = configuration.get("inputs", {})
        # NOTE(review): presumably a sequence of inspect.Parameter, since `mount` passes it to inspect.Signature
        self.parameters = utils.parameters.load(self.inputs)
        # the mimetype of the model responses
        self.output_type: str = configuration.get("output_type", "application/json")
        # get the tags of the model
        self.tags = configuration.get("tags", [])

        # is the model currently loaded
        self._loaded = False

        # lock to avoid loading and unloading at the same time
        self.load_lock = asyncio.Lock()

        # get the selected interface of the model
        # NOTE: done last so that the interface constructor, which receives `self`, sees a fully initialized
        #   model (including `_loaded` and `load_lock`)
        interface_name: typing.Optional[str] = configuration.get("interface", None)
        self.interface = (
            self.registry.interface_registry.interface_types[interface_name](self)
            if interface_name is not None else None
        )

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self.name}>"

    @property
    def api_base(self) -> str:
        """
        Base for the API routes
        :return: the base for the API routes, i.e. the registry API base followed by the model name
        """

        return f"{self.registry.api_base}/{self.name}"

    @property
    def name(self):
        """
        Get the name of the model
        :return: the name of the model (the name of its environment directory)
        """

        return self.path.name

    def get_information(self):
        """
        Get information about the model
        :return: information about the model
        """

        return {
            "name": self.name,
            "summary": self.summary,
            "description": self.description,
            "inputs": self.inputs,
            "output_type": self.output_type,
            "tags": self.tags,
            # NOTE(review): this is the interface object itself (or None), not its name - confirm that callers
            #   are able to serialize it
            "interface": self.interface,
        }

    async def load(self) -> None:
        """
        Load the model within the model manager.

        Does nothing if the model is already loaded. Otherwise, the model currently referenced by the registry
        (if any) is unloaded first, then this model is loaded and recorded as the registry's loaded model.
        """

        async with self.load_lock:
            # if the model is already loaded, skip
            if self._loaded:
                return

            # unload the currently loaded model if any
            # NOTE: skip it if the registry (inconsistently) still points to this model: its `unload` would try to
            #   acquire `load_lock`, which we already hold, and asyncio.Lock is not reentrant (deadlock)
            current_model = self.registry.current_loaded_model
            if current_model is not None and current_model is not self:
                await current_model.unload()

            # load the model depending on the implementation
            await self._load()

            # mark the model as loaded
            self._loaded = True
            # mark the model as the registry loaded model
            self.registry.current_loaded_model = self

    @abc.abstractmethod
    async def _load(self):
        """
        Load the model
        Do not call manually, use `load` instead.
        """

    async def unload(self) -> None:
        """
        Unload the model within the model manager.

        Does nothing if the model is not loaded. Otherwise, the implementation-specific unloading is performed, a
        garbage collection is forced, and the registry's loaded model is cleared if it was this model.
        """

        async with self.load_lock:
            # if we are not already loaded, stop
            if not self._loaded:
                return

            # unload the model depending on the implementation
            await self._unload()

            # force the garbage collector to clean the memory
            gc.collect()

            # mark the model as unloaded
            self._loaded = False

            # if we are the registry current loaded model, remove this status
            if self.registry.current_loaded_model is self:
                self.registry.current_loaded_model = None

    @abc.abstractmethod
    async def _unload(self):
        """
        Unload the model
        Do not call manually, use `unload` instead.
        """

    async def infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
        """
        Infer our payload through the model within the model manager.

        The registry's `infer_lock` is held for the whole duration of the stream, so inferences of all models
        sharing this registry are serialized. The model is loaded on demand before inferring.
        :param kwargs: the inference arguments, forwarded to `_infer`
        :return: the response of the model, as chunks of bytes
        """

        async with self.registry.infer_lock:
            # ensure that the model is loaded
            await self.load()

            # model specific inference part
            async for chunk in self._infer(**kwargs):
                yield chunk

    @abc.abstractmethod
    async def _infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
        """
        Infer our payload through the model.
        Implementations must return an async iterator of bytes (e.g. be an async generator), since `infer`
        consumes the result with `async for`.
        :return: the response of the model
        """

    def mount(self, application: fastapi.FastAPI) -> None:
        """
        Add the model to the api.

        Mount the interface (if any), then register a POST `{api_base}/infer` route whose signature is built from
        the configuration parameters and whose response streams the output of `infer`.
        :param application: the fastapi application
        """

        # mount the interface if selected
        if self.interface is not None:
            self.interface.mount(application)

        # create an endpoint wrapping the inference inside a fastapi call
        # the arguments will be loaded from the configuration files. Use kwargs for the definition
        async def infer_api(**kwargs) -> fastapi.responses.StreamingResponse:
            # NOTE: work around an issue where an UploadFile cannot be given as-is to a StreamingResponse
            # NOTE: compare type(value).__name__ to "UploadFile" instead of using isinstance, because FastAPI does
            #   not instantiate its own fastapi.UploadFile class but the Starlette UploadFile class, which is an
            #   implementation detail that may change in the future
            kwargs = {
                key: UploadFileFix(value) if type(value).__name__ == "UploadFile" else value
                for key, value in kwargs.items()
            }

            # return a streaming response around the inference call
            return fastapi.responses.StreamingResponse(
                content=self.infer(**kwargs),
                media_type=self.output_type,
                headers={
                    # if the data is not text-like, mark it as an attachment to avoid display issue with Swagger UI
                    "content-disposition": "inline" if utils.mimetypes.is_textlike(self.output_type) else "attachment"
                }
            )

        # update the signature of the function to use the configuration parameters
        # (FastAPI inspects this signature to build the request parameters of the route)
        infer_api.__signature__ = inspect.Signature(parameters=self.parameters)

        # format the description
        description_sections: list[str] = []
        if self.description is not None:
            description_sections.append(f"# Description\n{self.description}")
        if self.interface is not None:
            description_sections.append(f"# Interface\n**[Open Dedicated Interface]({self.interface.route})**")

        description: str = "\n".join(description_sections)

        # add the inference endpoint on the API
        application.add_api_route(
            f"{self.api_base}/infer",
            infer_api,
            methods=["POST"],
            tags=self.tags,
            summary=self.summary,
            description=description,
            response_class=fastapi.responses.StreamingResponse,
            responses={
                200: {"content": {self.output_type: {}}}
            },
        )