import glob
import yaml
import json
import logging
import sys
import os
import re

from datetime import datetime
from typing import List

import click
from rasterio import warp, features, open as ropen, crs, errors  # type:ignore
from rio_stac.stac import bbox_to_geom, get_projection_info  # type:ignore
import numpy
import rasterio
import pystac

from tqdm import tqdm

# Placeholder bounding box used as the initial spatial extent of a new collection.
COL_BBOX = [0.0] * 4
# Initial href given to the collection created by StacCreator.create_collection.
DEFAULT_COL_HREF = "http://hello.fr/collections/soilmoiture"


logging.basicConfig(
    format="%(levelname)s:%(message)s",
    # Level names are upper-case in the logging module; upper-casing lets
    # LOGLEVEL=debug work instead of raising "Unknown level".
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger(__name__)

# --- Raster analysis and parsing ---


# EPSG:4326 (WGS84 lon/lat): target CRS into which Info reprojects raster footprints.
EPSG_4326 = crs.CRS.from_epsg(4326)  # pylint: disable=c-extension-no-member


    
class NoSpatialLayerException(Exception):
    """Signals that a raster has no spatial layer (not raised in this module)."""


class ScalesInputFormatException(Exception):
    """Raised when the ASSET_SCALES option is malformed or does not match the raster bands."""

class OffsetsInputFormatException(Exception):
    """Raised when the ASSET_OFFSETS option is malformed or does not match the raster bands."""

class ProdDateNotFound(Exception):
    """Raised when no PROD_DATETIME_REGEX entry matches a product name."""

class Info:
    """Collect footprint, metadata and per-band details of one raster file."""

    def __init__(self, raster_file):
        """Open ``raster_file`` once and keep what the STAC builder needs.

        Args:
            raster_file: str, path of the raster to inspect.

        The footprint (``bbox``/``geom``) is expressed in EPSG:4326. Band
        statistics are not computed here; ``band_info`` fills ``stats`` lazily.
        """
        self.raster_file = raster_file
        with ropen(self.raster_file) as dataset:
            native_footprint = bbox_to_geom(dataset.bounds)
            # Reproject the footprint to "epsg:4326" before taking its bounds.
            wgs84_footprint = warp.transform_geom(
                dataset.crs, EPSG_4326, native_footprint
            )
            self.bbox = features.bounds(wgs84_footprint)
            self.geom = bbox_to_geom(self.bbox)
            self.meta = dataset.meta
            self.gsd = dataset.res[0]
            self.proj_ext_info = get_projection_info(dataset)
            self.nodata = dataset.nodata
            self.area_or_point = dataset.tags().get("AREA_OR_POINT", "").lower()
            self.bands = dataset.indexes
        self.stats = None

    def band_info(self, band: int):
        """Describe one band of the raster.

        Args:
            band: 1-based band index.

        Returns:
            (metadata, statistics). ``metadata`` holds data_type, scale, offset
            and, when available, sampling, nodata (NaN/±inf spelled as strings)
            and unit. ``statistics`` holds mean/minimum/maximum/stddev, or is
            empty when rasterio raises StatisticsError.

        Raises:
            ValueError: ``band`` is lower than 1.
        """
        if band <= 0:
            raise ValueError('The "band" parameter value starts at 1')

        position = band - 1
        with ropen(self.raster_file) as dataset:
            metadata = {
                "data_type": dataset.dtypes[position],
                "scale": dataset.scales[position],
                "offset": dataset.offsets[position],
            }
            if self.area_or_point:
                metadata["sampling"] = self.area_or_point

            # Nodata is only forwarded when the file defines one; non-finite
            # values are written as strings.
            nodata_value = dataset.nodata
            if nodata_value is not None:
                if numpy.isnan(nodata_value):
                    metadata["nodata"] = "nan"
                elif numpy.isinf(nodata_value):
                    metadata["nodata"] = "inf" if nodata_value > 0 else "-inf"
                else:
                    metadata["nodata"] = nodata_value

            band_unit = dataset.units[position]
            if band_unit is not None:
                metadata["unit"] = band_unit

            statistics = {}
            try:
                # Statistics of every band are computed once, then reused.
                if not self.stats:
                    self.stats = dataset.stats(approx=True)
                band_stats = self.stats[position]
            except rasterio.errors.StatisticsError as err:
                logger.warning("Unable to compute relevant statistics: %s", err)
            else:
                statistics = {
                    "mean": band_stats.mean,
                    "minimum": band_stats.min,
                    "maximum": band_stats.max,
                    "stddev": band_stats.std,
                }

        return metadata, statistics

        

class StacCreator:
    """Create a local, self-contained STAC catalog from raster products.

    Settings come from a YAML file with a ``mandatory`` section and an
    optional ``optionnal`` section (spelling kept as in existing configs).
    """

    def __init__(self, yaml_conf_file: str):
        """Load the YAML configuration and prepare input/output directories.

        Args:
            yaml_conf_file: path of the YAML configuration file.

        Raises:
            KeyError: the ``mandatory`` section or one of its keys is missing.
            FileNotFoundError: PRODUCTS_BASE_DIR does not exist.
        """
        with open(yaml_conf_file, "r", encoding="utf-8") as yamlconf:
            # The configuration only holds mappings, lists and scalars, so the
            # safe loader is sufficient. An empty file loads as None.
            col_details: dict = yaml.safe_load(yamlconf) or {}

        try:
            mandatories: dict = col_details["mandatory"]

            self.products_base_dir: str = mandatories["PRODUCTS_BASE_DIR"]
            self.out_stac_dir: str = mandatories["STAC_OUTPUT_DIR"]
            self.col_id: str = mandatories["COL_ID"]
            self.col_title: str = mandatories["COL_TITLE"]
            self.col_desc: str = mandatories["COL_DESC"]
            self.col_qlurl: str = mandatories["COL_QUICLOOK_URL"]
            self.asset_prefix: str = mandatories["PROD_NAME_PREFIX"]
            self.asset_extention: str = mandatories["PROD_NAME_EXT"]
            self.asset_key: str = mandatories["ASSET_KEY"]
            self.asset_desc: str = mandatories["ASSET_DESC"]
            self.col_contact: List[dict] = mandatories["CONTACT"]
            # List of {"re": <regex>, "dt": <strptime format>} entries.
            self.prod_regex: List[dict] = mandatories["PROD_DATETIME_REGEX"]
        except KeyError as e:
            logger.error("Key missing in yaml config file: %s", e)
            raise

        # The optional section may be absent, or present but empty (None).
        optionnals: dict = col_details.get("optionnal") or {}
        self.asset_scales: List[dict] | None = optionnals.get("ASSET_SCALES")
        self.asset_offsets: List[dict] | None = optionnals.get("ASSET_OFFSETS")
        self.asset_nodata: float | None = optionnals.get("ASSET_NODATA")
        self.asset_properties: dict = {}

        if not os.path.exists(self.products_base_dir):
            raise FileNotFoundError(f"Directory {self.products_base_dir} not found.")

        os.makedirs(self.out_stac_dir, exist_ok=True)

    def list_products(self) -> List[str]:
        """List product files under PRODUCTS_BASE_DIR, recursively.

        Files are matched with the glob pattern
        ``{PROD_NAME_PREFIX}*.{PROD_NAME_EXT}``. The result is sorted so items
        are always processed in the same order.

        Raises:
            FileNotFoundError: no file matches the pattern.
        """
        filepat = f"{self.asset_prefix}*.{self.asset_extention}"
        searchpat = os.path.join(self.products_base_dir, "**", filepat)

        filelist = sorted(glob.glob(searchpat, recursive=True))
        if not filelist:
            raise FileNotFoundError(f"No file found using glob pattern: {searchpat}")

        logger.debug("%d files found", len(filelist))
        return filelist

    def _find_prod_date(self, itemid: str):
        """Find the product date string inside an item id.

        PROD_DATETIME_REGEX entries are tried in order; the first whose
        ``re`` pattern matches wins.

        Returns:
            (strptime format from the entry's ``dt``, matched text).

        Raises:
            ProdDateNotFound: no configured regex matches ``itemid``.
        """
        for regex in self.prod_regex:
            result = re.search(regex["re"], itemid)
            if result:
                return regex["dt"], result.group()

        raise ProdDateNotFound(f"Product datetime not found for {itemid}. Add a custom date regex in yaml_conf_file.")

    def _parse_scaleoffset_option(self, scaleoffset: List[dict], meta: str):
        """Convert an ASSET_SCALES/ASSET_OFFSETS option into a per-band tuple.

        Args:
            scaleoffset: entries such as ``{"band_index": 1, "scale": 0.01}``;
                ``band_index`` is 1-based and the value is read under ``meta``.
            meta: ``"scale"`` or ``"offset"``.

        Returns:
            Tuple of floats ordered by band index.

        Raises:
            ScalesInputFormatException / OffsetsInputFormatException: an entry
                is malformed, a band index is outside ``1..len(scaleoffset)``,
                or a band is given twice (which would leave another unset).
        """
        values: dict = {}
        try:
            count = len(scaleoffset)
            for band in scaleoffset:
                index = int(band["band_index"])
                # Reject 0/negative indices: they would silently wrap around.
                if not 1 <= index <= count:
                    raise ValueError(f"band_index {index} out of range 1..{count}")
                if index in values:
                    raise ValueError(f"band_index {index} defined more than once")
                values[index] = float(band[meta])
        except (KeyError, TypeError, ValueError) as e:
            if meta == "scale":
                raise ScalesInputFormatException(f"Wrong format for ASSET_SCALES in json_conf_file: {e}") from e
            if meta == "offset":
                raise OffsetsInputFormatException(f"Wrong format for ASSET_OFFSETS in json_conf_file: {e}") from e
            raise

        return tuple(values[index] for index in range(1, count + 1))

    @staticmethod
    def _sidecar_path(prodfile: str, extension: str) -> str:
        """Return ``prodfile`` with its file extension replaced by ``extension``."""
        # splitext only touches the final suffix, never directory names.
        return os.path.splitext(prodfile)[0] + extension

    def _get_asset_properties(self, prodfile: str):
        """Build the assets and extra properties of a product item.

        Reads the JSON sidecar (same path, ``.JSON`` extension) and references
        the PNG quicklook (same path, ``.PNG`` extension).

        Returns:
            (assets, properties). ``assets`` is keyed by ASSET_KEY (the raster,
            described by ASSET_DESC) and "QL" (the quicklook). ``properties``
            holds "description" and "geotag", taken from the sidecar's
            ``properties.description`` and ``properties.title``.

        Raises:
            FileNotFoundError: the JSON sidecar does not exist.
        """
        jsonfile = self._sidecar_path(prodfile, ".JSON")
        qlfile = self._sidecar_path(prodfile, ".PNG")

        if not os.path.exists(jsonfile):
            raise FileNotFoundError(f"JSON file not found for {jsonfile}")

        with open(jsonfile, "r", encoding="utf-8") as mvjson:
            mv_meta = json.load(mvjson)

        properties = {
            "description": mv_meta["properties"]["description"],
            "geotag": mv_meta["properties"]["title"],
        }
        assets = {
            self.asset_key: pystac.Asset(
                href=prodfile,
                description=self.asset_desc,
                media_type=pystac.MediaType.COG,
            ),
            "QL": pystac.Asset(
                href=qlfile,
                description="Colorized image quicklook",
                media_type=pystac.MediaType.PNG,
            ),
        }

        return assets, properties

    def _apply_band_metadata(self, prodfile: str, band_count: int):
        """Write configured scales, offsets and nodata into the raster header.

        Every configured value is validated first. The file is then modified in
        place in a single ``r+`` session.

        Args:
            prodfile: path of the raster product.
            band_count: number of bands of the raster.

        Raises:
            ScalesInputFormatException / OffsetsInputFormatException: the
                configured list does not match the band count or is malformed.
        """
        filename = os.path.basename(prodfile)
        updates: dict = {}

        if self.asset_scales:
            if len(self.asset_scales) != band_count:
                raise ScalesInputFormatException(
                    f"""ASSET_SCALES in json_conf_file don't match bands count :
                    scales={len(self.asset_scales)} / bands={band_count}"""
                )
            logger.debug("Applying scales %s on %s", self.asset_scales, filename)
            updates["scales"] = self._parse_scaleoffset_option(self.asset_scales, "scale")

        if self.asset_offsets:
            if len(self.asset_offsets) != band_count:
                raise OffsetsInputFormatException(
                    f"""Offset tuple don't match bands count :
                    offsets={len(self.asset_offsets)} / bands={band_count}"""
                )
            logger.debug("Applying offsets  %s on %s", self.asset_offsets, filename)
            updates["offsets"] = self._parse_scaleoffset_option(self.asset_offsets, "offset")

        # "is not None": a nodata value of 0 is valid and must be applied.
        if self.asset_nodata is not None:
            logger.debug("Applying nodata %s on %s", str(self.asset_nodata), filename)
            updates["nodata"] = self.asset_nodata

        if updates:
            with ropen(prodfile, "r+", IGNORE_COG_LAYOUT_BREAK=True) as rast:
                for attribute, value in updates.items():
                    setattr(rast, attribute, value)

    def create_item(self, prodfile: str) -> "pystac.Item":
        """Create a STAC item from a local raster file.

        Configured scales/offsets/nodata are written into the raster header
        before the item is built.

        Args:
            prodfile: path of the raster product.

        Returns:
            The item. Its id is the file basename up to the first dot, and its
            datetime is parsed from that id with PROD_DATETIME_REGEX.

        Raises:
            ProdDateNotFound: no configured regex matches the item id.
            ScalesInputFormatException / OffsetsInputFormatException: invalid
                scales/offsets configuration for this raster.
            FileNotFoundError: the JSON sidecar is missing.
        """
        itemid = os.path.basename(prodfile).split(".")[0]
        dateformat, datestring = self._find_prod_date(itemid)
        logger.debug("found date : %s", datestring)
        proddate = datetime.strptime(datestring, dateformat)
        prodinfo = Info(prodfile)

        self._apply_band_metadata(prodfile, len(prodinfo.bands))

        # Custom function to get assets and properties of the product
        assets, properties = self._get_asset_properties(prodfile)

        return pystac.Item(
            id=itemid,
            geometry=prodinfo.geom,
            bbox=prodinfo.bbox,
            datetime=proddate,
            properties=properties,
            assets=assets,
        )

    def create_collection(self):
        """Create a STAC collection with no items and a placeholder extent.

        The spatial extent starts as COL_BBOX and the temporal extent as
        [None, None]; create_items_and_collection replaces them from items.
        """
        logger.info("# Create an empty STAC collection.")
        # Copy COL_BBOX so the module-level constant is never shared/mutated.
        spat_extent = pystac.SpatialExtent([list(COL_BBOX)])
        temp_extent = pystac.TemporalExtent(intervals=[[None, None]])  # type: ignore
        return pystac.Collection(
            id=self.col_id,
            title=self.col_title,
            extent=pystac.Extent(spat_extent, temp_extent),
            description=self.col_desc,
            href=DEFAULT_COL_HREF,
            providers=[
                pystac.Provider("INRAE"),
            ],
            license="etalab-2.0",
            assets={
                "thumbnail": pystac.Asset(
                    title="preview",
                    roles=["thumbnail"],
                    href=self.col_qlurl,
                    description="Collection thumbnail",
                    media_type=pystac.MediaType.PNG,
                )
            },
            extra_fields={
                "Contacts": self.col_contact,
            },
        )

    def create_items_and_collection(self):
        """Create one item per product and attach them all to a new collection.

        Returns:
            (collection, items).
        """
        logger.info("# Searching files.")
        prodlist = self.list_products()

        logger.info("# Create items.")
        items = [self.create_item(prodfile) for prodfile in tqdm(prodlist)]

        col = self.create_collection()
        logger.info("# Attach items to collection.")
        for item in tqdm(items):
            col.add_item(item)
        # Done once after all items are attached; doing it per item was O(n^2).
        col.make_all_asset_hrefs_absolute()
        # Replace the placeholder extent with the bbox/time span of the items.
        col.update_extent_from_items()

        return col, items

    def generate_collection(self):
        """Build the collection and save it inside a root catalog in out_stac_dir.

        The catalog (with the collection as its child) is written by
        ``normalize_and_save`` as SELF_CONTAINED. With pystac's default
        layout the collection presumably lands in
        ``{out_stac_dir}/{COL_ID}/collection.json`` — confirm if relied upon.
        """
        logger.info("# Generate and save a STAC catalog and collection in %s.", self.out_stac_dir)
        col, _ = self.create_items_and_collection()

        catalog = pystac.Catalog(
            id=f"{self.col_id}-catalog",
            description=f"Catalog: {self.col_title}",
        )
        catalog.add_child(col)
        catalog.normalize_and_save(
            root_href=self.out_stac_dir,
            catalog_type=pystac.CatalogType.SELF_CONTAINED,
        )


@click.command()
@click.argument(
    "yaml_conf_file",
    # dir_okay=False: a directory would pass exists=True and then fail in open().
    type=click.Path(exists=True, dir_okay=False),
)
def run(yaml_conf_file: str):
    """STAC creation from local raster files.

    YAML_CONF_FILE: yaml file containing parameters for output collection.
    """
    logger.info("Processing yaml conf : %s", yaml_conf_file)
    StacCreator(yaml_conf_file).generate_collection()


if __name__ == "__main__":
    run()