"""Local STAC creation example.

Author:
    Loïc Lozac'h, INRAE.
"""

# pylint: disable=C0103,W1203,W0718,W0707,E1120,R0801
import glob
import logging
import os
import re
import sys
import zipfile
from datetime import datetime, timedelta
from typing import List

import click
import numpy
import pystac
import yaml
from pystac import Asset, Collection
from pystac.extensions.scientific import CollectionScientificExtension
from rasterio import crs, errors, features, warp  # type:ignore
from rasterio import open as ropen
from rio_stac.stac import bbox_to_geom, get_projection_info  # type:ignore
from teledetection.upload import stac
from tqdm import tqdm  # type:ignore

# Log to stdout. Verbosity comes from the LOGLEVEL environment variable and
# falls back to INFO when that variable is unset.
logging.basicConfig(
    format="%(levelname)s:%(message)s",
    level=os.environ.get("LOGLEVEL", "INFO"),
    stream=sys.stdout,
)
logger = logging.getLogger(__name__)
# NOTE(review): the string below is a bare expression statement with no
# effect; it is not the module docstring (that one sits at the top of the file).
"""Raster analysis and parsing module."""


# EPSG:4326 CRS object, used as the target CRS when raster footprints are
# reprojected.
EPSG_4326 = crs.CRS.from_epsg(4326)  # pylint: disable=c-extension-no-member


class NoSpatialLayerException(Exception):
    """Error type signalling that no spatial layer is available.

    Declared for callers; the code in this module does not raise it itself.
    """

class ScalesInputFormatException(Exception):
    """Raised when the ASSET_SCALES option is malformed.

    This covers entries that cannot be parsed (missing keys, non-numeric
    values) as well as a number of scales that differs from the number of
    bands of the raster.
    """

class OffsetsInputFormatException(Exception):
    """Raised when the ASSET_OFFSETS option is malformed.

    This covers entries that cannot be parsed (missing keys, non-numeric
    values) as well as a number of offsets that differs from the number of
    bands of the raster.
    """

class ProdDateNotFound(Exception):
    """Raised when none of the configured regexes finds a date in a product name."""


class Info:
    """Collect georeferencing and per-band metadata from a raster file."""

    def __init__(self, raster_file):
        """Open the raster once and cache its dataset-level information.

        Args:
            raster_file: path of the raster, handed to rasterio's ``open``.

        """
        self.raster_file = raster_file
        with ropen(self.raster_file) as dataset:
            self.bbox = dataset.bounds
            self.geom = bbox_to_geom(self.bbox)
            # Footprint and bounding box expressed in EPSG:4326.
            self.geom_wgs84 = warp.transform_geom(dataset.crs, EPSG_4326, self.geom)
            self.bbox_wgs84 = features.bounds(self.geom_wgs84)
            self.meta = dataset.meta
            self.gsd = dataset.res[0]
            self.proj_ext_info = get_projection_info(dataset)
            self.nodata = dataset.nodata
            dataset_tags = dataset.tags()
            self.area_or_point = dataset_tags.get("AREA_OR_POINT", "").lower()
            self.bands = dataset.indexes
        # Statistics are computed on the first band_info() call and reused.
        self.stats = None

    def band_info(self, band: int):
        """Describe one band of the raster.

        Args:
            band: 1-based band index.

        Returns:
            A ``(metadata, statistics)`` pair of dicts. ``statistics`` is
            empty when rasterio reports a ``StatisticsError``.

        Raises:
            ValueError: if ``band`` is lower than 1.

        """
        if band <= 0:
            raise ValueError('The "band" parameter value starts at 1')

        position = band - 1
        with ropen(self.raster_file) as dataset:
            metadata = {
                "data_type": dataset.dtypes[position],
                "scale": dataset.scales[position],
                "offset": dataset.offsets[position],
            }
            if self.area_or_point:
                metadata["sampling"] = self.area_or_point

            # An unset nodata is not forwarded; non-finite values are spelled
            # out as strings.
            nodata_value = dataset.nodata
            if nodata_value is not None:
                if numpy.isnan(nodata_value):
                    metadata["nodata"] = "nan"
                elif numpy.isposinf(nodata_value):
                    metadata["nodata"] = "inf"
                elif numpy.isneginf(nodata_value):
                    metadata["nodata"] = "-inf"
                else:
                    metadata["nodata"] = nodata_value

            band_unit = dataset.units[position]
            if band_unit is not None:
                metadata["unit"] = band_unit

            band_stats = {}
            try:
                if not self.stats:
                    self.stats = dataset.stats(approx=True)
                computed = self.stats[position]
                band_stats = {
                    "mean": computed.mean,
                    "minimum": computed.min,
                    "maximum": computed.max,
                    "stddev": computed.std,
                }
            except errors.StatisticsError as exc:
                logger.warning("Unable to compute relevant statistics: %s", exc)

            return metadata, band_stats


class StacCreator:
    """Creates local STAC catalog."""

    def __init__(self, yaml_conf_file: str):
        """Initialize StacCreator with yaml conf file."""
        try:
            with open(yaml_conf_file, "r", encoding="utf-8") as yamlconf:
                col_details: dict = yaml.full_load(yamlconf)
            mandatories: dict = col_details["mandatory"]

            self.products_base_dir: str = mandatories["PRODUCTS_BASE_DIR"]
            self.col_id: str = mandatories["COL_ID"]
            self.col_title: str = mandatories["COL_TITLE"]
            self.col_desc: str = mandatories["COL_DESC"]
            self.col_qlurl: str = mandatories["COL_QUICLOOK_URL"]
            self.prod_prefix: str = mandatories["PROD_NAME_PREFIX"]
            self.prod_extension: str = mandatories["PROD_NAME_EXT"]
            self.assets: List[dict] = mandatories["ASSETS"]
            self.col_contact: List[dict] = mandatories["CONTACT"]
            self.col_providers: List[dict] = mandatories["PROVIDERS"]
            self.col_doi: str = mandatories["DOI"]
            self.prod_regex: dict = mandatories["PROD_DATETIME_REGEX"]

        except KeyError as e:
            logger.error("Key missing in json config file.")
            raise e

        self.asset_scales: List[dict] | None = None
        self.asset_offsets: List[dict] | None = None
        self.asset_nodata: float | None = None
        self.asset_properties: dict = {}

        if "optional" in col_details.keys():
            optionals: dict = col_details["optional"]

            if "ASSET_SCALES" in optionals.keys():
                self.asset_scales = optionals["ASSET_SCALES"]

            if "ASSET_OFFSETS" in optionals.keys():
                self.asset_offsets = optionals["ASSET_OFFSETS"]

            if "ASSET_NODATA" in optionals.keys():
                self.asset_nodata = optionals["ASSET_NODATA"]

        if not os.path.exists(self.products_base_dir):
            raise FileNotFoundError(f"Directory {self.products_base_dir} not found.")

    def _extract_tif_in_zip(self, zip_file_path, search_string="_LST.TIF") -> str:
        """
        Dézippe un fichier correspondant au search_string.

        :param zip_file_path: Chemin du fichier ZIP à dézipper
        :param search_string: String de recherche
        """
        if not os.path.exists(zip_file_path):
            raise FileNotFoundError(
                f"Le fichier ZIP spécifié n'existe pas: {zip_file_path}"
            )

        with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
            for f in zip_ref.namelist():
                if f.endswith(search_string):
                    extracted = zip_ref.extract(f, self.products_base_dir)
                    if not os.path.exists(extracted):
                        raise FileNotFoundError(
                            f"Failed to extract search file from {zip_file_path}."
                        )
                    return extracted

        raise FileNotFoundError(f"Failed to extract search file from {zip_file_path}.")

    def _get_item_properties(self, prodfile: str) -> dict:
        """Custom function to add properties to product item."""
        prod_split = os.path.basename(prodfile).split("_")
        prod_mgrs = prod_split[2]
        prod_synth = prod_split[1]
        return {"s2:mgrs": prod_mgrs, "synthesis":prod_synth}

    def _find_prod_date(self, itemid: str):
        """Find product date inside filename according to regex in stac_conf.yaml."""
        for regex in self.prod_regex:
            result = re.search(regex["re"], itemid)
            if result:
                return regex["dt"], result.group()

        raise ProdDateNotFound(
            f"Product datetime not found for {itemid}. Add a custom date regex in yaml_conf_file."
        )
    
    def _find_prod_enddate(self, itemid: str, startdate:str):
        """Find product end date inside filename according to regex in stac_conf.yaml."""
        extract = itemid.replace(startdate,"")
        for regex in self.prod_regex:
            result = re.search(regex["re"], extract)
            if result:
                return regex["dt"], result.group()

        raise ProdDateNotFound(
            f"Product datetime not found for {itemid}. Add a custom date regex in yaml_conf_file."
        )

    def _parse_scaleoffset_option(self, scaleoffset: List[dict], meta: str):
        """Parse scale/offset defined in stac_conf.yaml."""
        try:
            bandtuple = [0.0] * len(scaleoffset)
            for band in scaleoffset:
                index = int(band["band_index"]) - 1
                value = float(band[meta])
                bandtuple[index] = value
        except Exception as e:
            if meta == "scale":
                raise ScalesInputFormatException(
                    f"Wrong format for ASSET_SCALES in json_conf_file: {e}"
                )
            if meta == "offset":
                raise OffsetsInputFormatException(
                    f"Wrong format for ASSET_OFFSETS in json_conf_file: {e}"
                )

        return tuple(bandtuple)

    def create_providers(self):
        """Create a provider."""
        providers = []
        for p in self.col_providers:
            providers.append(
                pystac.Provider(
                    name=p["name"],
                    roles=["producer"],
                    url=p["url"],
                )
            )
        return providers

    def create_asset(
        self,
        filename: str,
        title: str,
        description: str,
        roles: list[str] | None = None,
    ):
        """Create a pystac asset."""
        if not roles:
            media_type = None
        elif "data" in roles:
            media_type = pystac.MediaType.COG
        elif "overview" in roles:
            media_type = pystac.MediaType.PNG
        else:
            media_type = None
        return Asset(
            href=filename,
            title=title,
            description=description,
            roles=roles or ["data"],
            media_type=media_type,
        )

    def create_item(self, prodfile: str) -> pystac.Item:
        """Create stac item from local file."""
        itemid = os.path.basename(prodfile).split(".")[0]
        dateregex, datestring = self._find_prod_date(itemid)
        logger.debug(f"found date : {datestring}")
        proddate = datetime.strptime(datestring, dateregex)
        

        # Custom asset formatting
        image4cog = prodfile
        qlpng = prodfile.replace(".TIF", ".PNG")

        prodinfo = Info(image4cog)

        # Add scales, offsets, nodata in source file header before creating item
        if self.asset_scales:
            if len(self.asset_scales) != len(prodinfo.bands):
                raise ScalesInputFormatException(
                    f"""ASSET_SCALES in json_conf_file don't match bands count :
                    scales={len(self.asset_scales)} / bands={len(prodinfo.bands)}"""
                )
            logger.debug(
                "Applying scales %s on %s",
                self.asset_scales,
                os.path.basename(prodfile),
            )

            scaletuple = self._parse_scaleoffset_option(self.asset_scales, "scale")

            with ropen(image4cog, "r+", IGNORE_COG_LAYOUT_BREAK=True) as rast:
                rast.scales = tuple(scaletuple)
        if self.asset_offsets:
            if len(self.asset_offsets) != len(prodinfo.bands):
                raise OffsetsInputFormatException(
                    f"""Offset tuple don't match bands count :
                    offsets={len(self.asset_offsets)} / bands={len(prodinfo.bands)}"""
                )
            logger.debug(
                "Applying offsets  %s on %s",
                self.asset_offsets,
                os.path.basename(prodfile),
            )

            offsettuple = self._parse_scaleoffset_option(self.asset_offsets, "offset")

            with ropen(image4cog, "r+", IGNORE_COG_LAYOUT_BREAK=True) as rast:
                rast.offsets = offsettuple
        if self.asset_nodata:
            logger.debug(
                "Applying nodata %s on %s",
                str(self.asset_nodata),
                os.path.basename(prodfile),
            )
            with ropen(image4cog, "r+", IGNORE_COG_LAYOUT_BREAK=True) as rast:
                rast.nodata = self.asset_nodata

        # Custom function to get properties from prod
        item_properties = self._get_item_properties(prodfile)
        if item_properties["synthesis"] == "S2MONTHLY":
            titleadd = "Monthly synthesis of"
            end_datetime=(proddate + timedelta(days=31)).replace(day=1)
        else:
            titleadd = "Yearly synthesis of"
            end_datetime=(proddate + timedelta(days=366)).replace(day=1)
        
        assets = {}
        for asset in self.assets:
            if asset["key"] == "SW":
                assets[asset["key"]] = self.create_asset(
                    image4cog, " ".join([titleadd,asset["title"]]), asset["desc"], ["data"]
                )
            if asset["key"] == "QL":
                assets[asset["key"]] = self.create_asset(
                    qlpng, asset["title"], asset["desc"], ["overview"]
                )

        # Creating item
        item = pystac.Item(
            id=itemid,
            geometry=prodinfo.geom_wgs84,
            bbox=prodinfo.bbox_wgs84,
            datetime=proddate,
            start_datetime=proddate,
            end_datetime=end_datetime,
            properties=item_properties,
            assets=assets,
        )

        return item

    def create_collection(self, items: list[pystac.Item]):
        """Create a pystac collection."""
        col = Collection(
            id=self.col_id,
            description=self.col_desc,
            title=self.col_title,
            license="etalab-2.0",
            providers=self.create_providers(),
            extent=pystac.Extent(
                pystac.SpatialExtent(4 * [0]),
                pystac.TemporalExtent(intervals=[[None, None]]),
            ),
            assets={
                "thumbnail": pystac.Asset(
                    title="preview",
                    roles=["thumbnail"],
                    href=self.col_qlurl,
                    description="Collection thumbnail",
                    media_type=pystac.MediaType.PNG,
                )
            },
            extra_fields={"Contacts": self.col_contact},
        )

        # Apply DOI
        sc_ext = CollectionScientificExtension.ext(col, add_if_missing=True)
        sc_ext.apply(doi=self.col_doi)

        # Check assets unicity
        assert len(set(tuple(sorted(item.assets.keys())) for item in items)) == 1, (
            "All items must have the same assets!"
        )

        # Add items to collection
        for item in items:
            col.add_item(item)

        # Update extent
        col.update_extent_from_items()

        return col

    def list_products(self) -> List[str]:
        """List products to process according regex."""
        filepat = f"{self.prod_prefix}*.{self.prod_extension}"
        searchpat = os.path.join(self.products_base_dir, "**", filepat)

        filelist = glob.glob(searchpat, recursive=True)
        if len(filelist) == 0:
            raise FileNotFoundError(f"No file found with using regex: {searchpat}")

        logger.info(f"{len(filelist)} files found")
        return filelist

    def publish_collection(
        self, col: pystac.Collection, out_dir: str, assets_overwrite: bool = True,
        storage_bucket: str = "sm1-gdc-tests"
    ):
        """Save the collection and upload it."""
        col.normalize_hrefs(out_dir)
        col.save(pystac.CatalogType.ABSOLUTE_PUBLISHED)

        uploader = stac.StacUploadTransactionsHandler(storage_bucket=storage_bucket,assets_overwrite=assets_overwrite)
        uploader.load_and_publish(os.path.join(out_dir, "collection.json"))


@click.command()
@click.argument("yaml_conf_file", type=click.Path(exists=True))
@click.argument("out_dir", type=click.Path())
@click.argument("storage_bucket")
def run(yaml_conf_file: str, out_dir: str, storage_bucket: str = "sm1-gdc-tests"):
    """STAC creation from local raster files.

    Args:
        yaml_conf_file: yaml file containing parameters for output collection.
        out_dir: output directory for stac collection creation
        storage_bucket: Targeted Storage Bucket for CDS Upload.
    """
    # NOTE(review): storage_bucket is declared as a plain click argument, so
    # the Python default above presumably never applies on the command line.
    logger.info("Processing yaml conf : %s", yaml_conf_file)
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(out_dir, exist_ok=True)

    logger.info("# Initialization")
    sc = StacCreator(yaml_conf_file)

    logger.info("# Search products")
    prod_list = sc.list_products()

    logger.info("# Creating items")
    items = [sc.create_item(prod) for prod in tqdm(prod_list)]

    logger.info("# Creating collection")
    col = sc.create_collection(items)

    logger.info("# Publish collection")
    sc.publish_collection(col=col, out_dir=out_dir, storage_bucket=storage_bucket)


# Script entry point: run() is a click command, so it reads its arguments
# from the command line.
if __name__ == "__main__":
    run()
