# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""Data catalog providers."""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Protocol

import datafusion._internal as df_internal

if TYPE_CHECKING:
    import pyarrow as pa

    from datafusion import DataFrame, SessionContext
    from datafusion.context import TableProviderExportable

try:
    from warnings import deprecated  # Python 3.13+
except ImportError:
    from typing_extensions import deprecated  # fallback for Python < 3.13


# Names exported by ``from <this module> import *``.
__all__ = [
    "Catalog",
    "CatalogList",
    "CatalogProvider",
    "CatalogProviderList",
    "Schema",
    "SchemaProvider",
    "Table",
]


class CatalogList:
    """DataFusion data catalog list.

    Thin wrapper around an internal ``RawCatalogList``; every method delegates
    to the wrapped object stored in ``catalog_list``.
    """

    def __init__(self, catalog_list: df_internal.catalog.RawCatalogList) -> None:
        """This constructor is not typically called by the end user."""
        self.catalog_list = catalog_list

    def __repr__(self) -> str:
        """Print a string representation of the catalog list."""
        return self.catalog_list.__repr__()

    def names(self) -> set[str]:
        """This is an alias for `catalog_names`."""
        return self.catalog_names()

    def catalog_names(self) -> set[str]:
        """Returns the set of catalog names in this catalog list."""
        return self.catalog_list.catalog_names()

    @staticmethod
    def memory_catalog(ctx: SessionContext | None = None) -> CatalogList:
        """Create an in-memory catalog provider list.

        Args:
            ctx: Optional session context, forwarded unchanged to the internal
                ``RawCatalogList.memory_catalog`` constructor.
        """
        catalog_list = df_internal.catalog.RawCatalogList.memory_catalog(ctx)
        return CatalogList(catalog_list)

    def catalog(self, name: str = "datafusion") -> Catalog:
        """Returns the catalog with the given ``name`` from this catalog list.

        Internal ``RawCatalog`` objects are wrapped in :class:`Catalog`; any
        other object returned by the underlying list is passed through as-is.
        """
        catalog = self.catalog_list.catalog(name)

        return (
            Catalog(catalog)
            if isinstance(catalog, df_internal.catalog.RawCatalog)
            else catalog
        )

    def register_catalog(
        self,
        name: str,
        catalog: Catalog | CatalogProvider | CatalogProviderExportable,
    ) -> Catalog | None:
        """Register a catalog with this catalog list.

        A :class:`Catalog` wrapper is unwrapped to its internal raw catalog
        before being handed to the underlying list; providers are passed as-is.
        """
        if isinstance(catalog, Catalog):
            return self.catalog_list.register_catalog(name, catalog.catalog)
        return self.catalog_list.register_catalog(name, catalog)


class Catalog:
    """DataFusion data catalog.

    Thin wrapper around an internal ``RawCatalog``; every method delegates to
    the wrapped object stored in ``catalog``.
    """

    def __init__(self, catalog: df_internal.catalog.RawCatalog) -> None:
        """This constructor is not typically called by the end user."""
        self.catalog = catalog

    def __repr__(self) -> str:
        """Print a string representation of the catalog."""
        return self.catalog.__repr__()

    def names(self) -> set[str]:
        """This is an alias for `schema_names`."""
        return self.schema_names()

    def schema_names(self) -> set[str]:
        """Returns the set of schema names in this catalog."""
        return self.catalog.schema_names()

    @staticmethod
    def memory_catalog(ctx: SessionContext | None = None) -> Catalog:
        """Create an in-memory catalog provider.

        Args:
            ctx: Optional session context, forwarded unchanged to the internal
                ``RawCatalog.memory_catalog`` constructor.
        """
        catalog = df_internal.catalog.RawCatalog.memory_catalog(ctx)
        return Catalog(catalog)

    def schema(self, name: str = "public") -> Schema:
        """Returns the schema with the given ``name`` from this catalog.

        Internal ``RawSchema`` objects are wrapped in :class:`Schema`; any
        other object returned by the underlying catalog is passed through as-is.
        """
        schema = self.catalog.schema(name)

        return (
            Schema(schema)
            if isinstance(schema, df_internal.catalog.RawSchema)
            else schema
        )

    @deprecated("Use `schema` instead.")
    def database(self, name: str = "public") -> Schema:
        """Returns the schema with the given ``name`` from this catalog.

        Deprecated alias for :meth:`schema`.
        """
        return self.schema(name)

    def register_schema(
        self,
        name: str,
        schema: Schema | SchemaProvider | SchemaProviderExportable,
    ) -> Schema | None:
        """Register a schema with this catalog.

        A :class:`Schema` wrapper is unwrapped to its internal raw schema before
        being handed to the underlying catalog; providers are passed as-is.
        """
        if isinstance(schema, Schema):
            return self.catalog.register_schema(name, schema._raw_schema)
        return self.catalog.register_schema(name, schema)

    def deregister_schema(self, name: str, cascade: bool = True) -> Schema | None:
        """Deregister a schema from this catalog.

        Args:
            name: Name of the schema to remove.
            cascade: Forwarded unchanged to the underlying catalog's
                ``deregister_schema``.
        """
        return self.catalog.deregister_schema(name, cascade)


class Schema:
    """DataFusion Schema.

    A thin wrapper: every method below forwards to the internal ``RawSchema``
    held in ``_raw_schema``.
    """

    def __init__(self, schema: df_internal.catalog.RawSchema) -> None:
        """Wrap an internal schema object (end users rarely call this)."""
        self._raw_schema = schema

    def __repr__(self) -> str:
        """Return the representation of the wrapped schema."""
        return repr(self._raw_schema)

    @staticmethod
    def memory_schema(ctx: SessionContext | None = None) -> Schema:
        """Build an in-memory schema provider, forwarding ``ctx`` to the internals."""
        return Schema(df_internal.catalog.RawSchema.memory_schema(ctx))

    def names(self) -> set[str]:
        """Alias of :meth:`table_names`."""
        return self.table_names()

    def table_names(self) -> set[str]:
        """Names of every table registered in this schema."""
        # On the raw schema ``table_names`` is read as an attribute, not called.
        raw = self._raw_schema
        return raw.table_names

    def table(self, name: str) -> Table:
        """Fetch table ``name`` from this schema and wrap it in a :class:`Table`."""
        raw_table = self._raw_schema.table(name)
        return Table(raw_table)

    def register_table(
        self,
        name: str,
        table: Table | TableProviderExportable | DataFrame | pa.dataset.Dataset,
    ) -> None:
        """Add ``table`` to this schema under ``name``."""
        raw = self._raw_schema
        return raw.register_table(name, table)

    def deregister_table(self, name: str) -> None:
        """Drop the table called ``name`` from this schema."""
        raw = self._raw_schema
        return raw.deregister_table(name)

    def table_exist(self, name: str) -> bool:
        """Report whether a table called ``name`` exists in this schema."""
        raw = self._raw_schema
        return raw.table_exist(name)


@deprecated("Use `Schema` instead.")
class Database(Schema):
    """Deprecated alias of `Schema`; use `Schema` directly in new code."""


class Table:
    """A DataFusion table.

    The following kinds of table sources are currently supported internally:

    - Tables created using built-in DataFusion methods, such as
      reading from CSV or Parquet
    - pyarrow datasets
    - DataFusion DataFrames, which will be converted into a view
    - Externally provided tables implemented with the FFI PyCapsule
      interface (advanced)
    """

    __slots__ = ("_inner",)

    def __init__(
        self,
        table: Table | TableProviderExportable | DataFrame | pa.dataset.Dataset,
        ctx: SessionContext | None = None,
    ) -> None:
        """Wrap ``table`` in an internal ``RawTable``.

        Args:
            table: Any of the table sources listed in the class docstring.
            ctx: Optional session context, forwarded unchanged to ``RawTable``.
        """
        raw_table_cls = df_internal.catalog.RawTable
        self._inner = raw_table_cls(table, ctx)

    def __repr__(self) -> str:
        """Return the representation of the wrapped internal table."""
        inner = self._inner
        return repr(inner)

    @property
    def schema(self) -> pa.Schema:
        """The schema of this table, as reported by the internal table."""
        inner = self._inner
        return inner.schema

    @property
    def kind(self) -> str:
        """The kind of this table, as reported by the internal table."""
        inner = self._inner
        return inner.kind

    @staticmethod
    @deprecated("Use Table() constructor instead.")
    def from_dataset(dataset: pa.dataset.Dataset) -> Table:
        """Deprecated: wrap a :mod:`pyarrow.dataset` ``Dataset`` as a :class:`Table`."""
        return Table(dataset)


class CatalogProviderList(ABC):
    """Abstract base for catalog provider lists implemented in Python.

    Subclasses must provide :meth:`catalog_names` and :meth:`catalog`;
    :meth:`register_catalog` is an optional hook that does nothing by default.
    """

    @abstractmethod
    def catalog_names(self) -> set[str]:
        """Names of every catalog held by this list."""

    @abstractmethod
    def catalog(
        self, name: str
    ) -> CatalogProviderExportable | CatalogProvider | Catalog | None:
        """Look up the catalog called ``name`` in this list."""

    def register_catalog(  # noqa: B027
        self, name: str, catalog: CatalogProviderExportable | CatalogProvider | Catalog
    ) -> None:
        """Optionally add ``catalog`` to this list under ``name``.

        Lists that expose a fixed set of catalogs can leave this default,
        which does nothing.
        """


class CatalogProviderListExportable(Protocol):
    """Structural type for objects exposing a catalog provider list PyCapsule.

    Any object with a ``__datafusion_catalog_provider_list__`` method matches.
    https://docs.rs/datafusion/latest/datafusion/catalog/trait.CatalogProviderList.html
    """

    def __datafusion_catalog_provider_list__(self, session: Any) -> object:
        """Return the exported catalog provider list for ``session``."""
        ...


class CatalogProvider(ABC):
    """Abstract class for defining a Python based Catalog Provider.

    Subclasses must implement :meth:`schema_names` and :meth:`schema`.
    :meth:`register_schema` and :meth:`deregister_schema` are optional hooks
    whose default implementations do nothing.
    """

    @abstractmethod
    def schema_names(self) -> set[str]:
        """Set of the names of all schemas in this catalog."""
        ...

    @abstractmethod
    def schema(
        self, name: str
    ) -> SchemaProviderExportable | SchemaProvider | Schema | None:
        """Retrieve a specific schema from this catalog.

        The return annotation mirrors the schema types accepted by
        :meth:`register_schema`, in parallel with
        ``CatalogProviderList.catalog``.
        """
        ...

    def register_schema(  # noqa: B027
        self, name: str, schema: SchemaProviderExportable | SchemaProvider | Schema
    ) -> None:
        """Add a schema to this catalog.

        This method is optional. If your catalog provides a fixed list of schemas, you
        do not need to implement this method.
        """

    def deregister_schema(self, name: str, cascade: bool) -> None:  # noqa: B027
        """Remove a schema from this catalog.

        This method is optional. If your catalog provides a fixed list of schemas, you
        do not need to implement this method.

        Args:
            name: The name of the schema to remove.
            cascade: If true, deregister the tables within the schema.
        """


class CatalogProviderExportable(Protocol):
    """Structural type for objects exposing a catalog provider PyCapsule.

    Any object with a ``__datafusion_catalog_provider__`` method matches.
    https://docs.rs/datafusion/latest/datafusion/catalog/trait.CatalogProvider.html
    """

    def __datafusion_catalog_provider__(self, session: Any) -> object:
        """Return the exported catalog provider for ``session``."""
        ...


class SchemaProvider(ABC):
    """Abstract base for schema providers implemented in Python.

    Subclasses must provide :meth:`table_names`, :meth:`table` and
    :meth:`table_exist`. :meth:`owner_name`, :meth:`register_table` and
    :meth:`deregister_table` have defaults and may be overridden.
    """

    def owner_name(self) -> str | None:
        """Optionally report who owns this schema; ``None`` unless overridden."""
        return None

    @abstractmethod
    def table_names(self) -> set[str]:
        """Names of every table in this schema."""

    @abstractmethod
    def table(self, name: str) -> Table | None:
        """Look up the table called ``name`` in this schema."""

    def register_table(  # noqa: B027
        self, name: str, table: Table | TableProviderExportable | Any
    ) -> None:
        """Optionally add ``table`` to this schema under ``name``.

        Schemas that expose a fixed set of tables can leave this default,
        which does nothing.
        """

    def deregister_table(self, name: str, cascade: bool) -> None:  # noqa: B027
        """Optionally remove the table called ``name`` from this schema.

        Schemas that expose a fixed set of tables can leave this default,
        which does nothing and ignores ``cascade``.
        """

    @abstractmethod
    def table_exist(self, name: str) -> bool:
        """Whether a table called ``name`` exists in this schema."""


class SchemaProviderExportable(Protocol):
    """Structural type for objects exposing a schema provider PyCapsule.

    Any object with a ``__datafusion_schema_provider__`` method matches.
    https://docs.rs/datafusion/latest/datafusion/catalog/trait.SchemaProvider.html
    """

    def __datafusion_schema_provider__(self, session: Any) -> object:
        """Return the exported schema provider for ``session``."""
        ...
