Source code for fweather.fweather_data_cube

import fsspec
import pyproj
import tempfile
import rasterio
import pandas as pd
from shapely import Point
from tqdm import tqdm
import xarray as xr
from datetime import datetime
from pyproj import Transformer
from shapely.geometry import box
from pystac_client import Client
from shapely.ops import transform

from fweather.fweather_collection_get_list import collection_get_list
from fweather.fweather_core import name_band

# Module-level fsspec filesystem for the 'https' protocol; data_cube uses it
# (fs.get / fs.open) to fetch remote scene files.
fs = fsspec.filesystem('https')

# Collections data_cube knows how to assemble.
_SUPPORTED_COLLECTIONS = (
    'landsat-2', 'LANDSAT-16D-1', 'S2-16D-2', 'S2_L2A-1', 'samet_daily-1',
    'prec_merge_daily-1', 'prec_merge_hourly-1', 'GOES-GL-DSWRF-Daily-1',
)

# Collections handled by the lat/lon (non-rasterio) loaders.
_LATLON_COLLECTIONS = (
    'samet_daily-1', 'prec_merge_daily-1', 'GOES-GL-DSWRF-Daily-1',
)

# Raster collections whose file names carry the date in the 4th "_" field.
_DATE_IN_FOURTH_FIELD = (
    'AMZ1-WFI-L4-SR-1', 'S2-16D-2', 'LANDSAT-16D-1', 'landsat-2',
)


def _parse_bbox(bbox):
    """Return bbox as a (minx, miny, maxx, maxy) tuple of floats, or None."""
    if bbox is None:
        return None
    parts = bbox.split(',') if isinstance(bbox, str) else bbox
    values = tuple(float(part) for part in parts)
    if len(values) != 4:
        raise ValueError(
            "bbox must contain exactly four values: minx, miny, maxx, maxy."
        )
    return values


def _warn_skipped(image, error):
    """Emit a warning for a scene that could not be processed."""
    import warnings

    warnings.warn(
        f"Skipping scene {image}: {error!r}", RuntimeWarning, stacklevel=2
    )


def _require_scenes(scenes, collection):
    """Raise ValueError when no scene of the collection could be loaded."""
    if not scenes:
        raise ValueError(f"No scenes could be loaded for {collection}.")


def _with_time(data, stamp):
    """Attach a 'YYYYMMDD' stamp to data as a length-one time dimension."""
    when = pd.to_datetime(datetime.strptime(stamp, '%Y%m%d'))
    return data.assign_coords(time=when).expand_dims(dim="time")


def _raster_scene_stamp(collection, filename):
    """Extract the 'YYYYMMDD' date string from a raster scene file name."""
    fields = filename.split("_")
    if collection in _DATE_IN_FOURTH_FIELD:
        return fields[3]
    if collection == "S2_L2A-1":
        return fields[2].split('T')[0]
    return fields[-2]


def _iter_scenes(bands_dict, bands):
    """Yield every scene URL of the requested bands behind a progress bar."""
    for band in bands:
        scenes = bands_dict[band]
        yield from tqdm(desc='Fetching... ', unit=" scenes",
                        total=len(scenes), iterable=scenes)


def _load_prec_merge(collection, bands_dict, bands, bbox):
    """Download the GRIB scenes, stack them on time and clip them to bbox."""
    grib_attrs = (
        'GRIB_edition', 'GRIB_centre', 'GRIB_centreDescription',
        'GRIB_subCentre', 'Conventions', 'institution', 'history',
    )
    scenes = []
    for image in _iter_scenes(bands_dict, bands):
        try:
            with tempfile.NamedTemporaryFile() as tmp:
                fs.get(image, tmp.name)
                ds = xr.open_dataset(tmp.name, engine='cfgrib')
                ds = ds.drop_vars("prmsl")
                for key in grib_attrs:
                    ds.attrs.pop(key, None)
                ds = ds.drop_vars(['valid_time', 'surface', 'time', 'step'])
                stamp = image.split("/")[-1].split('.')[0].split("_")[2]
                # Load eagerly: the temporary file is deleted when the
                # with-block exits, so lazy reads afterwards would fail.
                scenes.append(_with_time(ds, stamp).load())
        except Exception as error:  # best effort: skip the broken scene
            _warn_skipped(image, error)

    _require_scenes(scenes, collection)
    cube = xr.combine_by_coords(scenes)

    min_lon, min_lat, max_lon, max_lat = bbox
    # The selection shifts negative longitudes by +360 (0..360 convention).
    min_lon_360 = min_lon + 360 if min_lon < 0 else min_lon
    max_lon_360 = max_lon + 360 if max_lon < 0 else max_lon
    return cube.sel(
        latitude=slice(min_lat, max_lat),
        longitude=slice(min_lon_360, max_lon_360)
    )


def _load_latlon_netcdf(collection, bands_dict, bands, bbox, point):
    """Open samet / GOES scenes, subset each one and concatenate on time."""
    scenes = []
    for image in _iter_scenes(bands_dict, bands):
        try:
            with fs.open(image) as f:
                ds = xr.open_dataset(f)
                if point is not None:
                    lat, lon = point
                    subset = ds.sel(lon=lon, lat=lat, method='nearest')
                else:
                    min_lon, min_lat, max_lon, max_lat = bbox
                    subset = ds.sel(
                        lon=slice(min_lon, max_lon),
                        lat=slice(min_lat, max_lat)
                    )
                if collection == "samet_daily-1":
                    subset = subset.drop_vars("nobs", errors='ignore')
                else:
                    # GOES scenes get their time from the file name and are
                    # multiplied by 0.1.
                    stamp = image.split("/")[-1].split(".")[0].split("_")[-1][0:8]
                    subset = _with_time(subset, stamp) * 0.1
                # Must be read while the remote file is still open.
                subset.load()
                scenes.append(subset)
        except Exception as error:  # best effort: skip the broken scene
            _warn_skipped(image, error)

    _require_scenes(scenes, collection)
    combined = xr.concat(scenes, dim="time")
    combined.attrs.clear()
    return combined.sortby("time")


def _load_raster(collection, bands_dict, bands, bbox, sample_image_path):
    """Clip every raster scene to bbox and merge one variable per band."""
    with rasterio.open(sample_image_path) as src:
        data_proj = src.crs
    to_data_crs = Transformer.from_crs(
        pyproj.CRS.from_epsg(4326), data_proj, always_xy=True
    ).transform
    clip_bounds = transform(to_data_crs, box(*bbox)).bounds

    band_cubes = []
    for band in bands:
        # A fresh list per band, so bands never leak into each other.
        scenes = []
        for image in bands_dict[band]:
            da = xr.open_dataarray(image, engine='rasterio')
            da = da.astype('int16')
            try:
                da = da.rio.clip_box(*clip_bounds)
                stamp = _raster_scene_stamp(collection, image.split('/')[-1])
                scenes.append(_with_time(da, stamp))
            except Exception as error:  # best effort: skip the broken scene
                _warn_skipped(image, error)
        _require_scenes(scenes, collection)
        cube = xr.combine_by_coords(scenes)
        band_cubes.append(
            cube.rename({'band_data': name_band(collection, band)})
        )
    return xr.merge(band_cubes)


def data_cube(stac_url, collection, start_date, end_date, tile=None, bbox=None, freq=None, bands=None, geom=None):
    """
    Create a virtual data cube using SpatioTemporal Asset Catalog (STAC).

    Args:
        stac_url (str): URL endpoint of the STAC API to query.
        collection (str): STAC collection identifier to retrieve data from.
        start_date (str): Start date for temporal filtering in 'YYYY-MM-DD' format.
        end_date (str): End date for temporal filtering in 'YYYY-MM-DD' format.
        tile: Accepted for compatibility; not used by this function.
        bbox (str/list, optional): Bounding box coordinates for spatial
            filtering. Can be a string of comma-separated values or a list
            [minx, miny, maxx, maxy]. Defaults to None.
        freq: Accepted for compatibility; not used by this function.
        bands (list): List of band identifiers to include.
        geom (list, optional): Point geometry; ``geom[0]['coordinates']`` is
            unpacked as ``(lat, lon)`` and used for a nearest-neighbour
            selection. Only honoured for 'samet_daily-1' and
            'GOES-GL-DSWRF-Daily-1'; 'prec_merge_daily-1' always clips by
            bbox. NOTE(review): GeoJSON orders coordinates as lon, lat -
            confirm the order callers actually pass.

    Returns:
        xarray.Dataset: The assembled cube, or None (after printing a
        message) when the collection is unsupported or the first requested
        band has no scenes.

    Raises:
        ValueError: If bands is empty, bbox is malformed or missing where it
            is required, geom is passed for a raster collection, or no scene
            could be loaded.

    Example:
        >>> prec_merge_cube = data_cube(
        ...     stac_url=stac_url,
        ...     collection="prec_merge_daily-1",
        ...     start_date="2024-01-01",
        ...     end_date="2024-12-31",
        ...     bbox="-47.2797,-17.0725,-45.4779,-15.4485",
        ...     bands=["merge_daily"]
        ... )
    """
    # Validate everything that needs no network access before opening STAC.
    if collection not in _SUPPORTED_COLLECTIONS:
        print(f"{collection} collection not yet supported.")
        return None
    if not bands:
        raise ValueError("bands must be a non-empty list of band identifiers.")

    is_raster = collection not in _LATLON_COLLECTIONS
    if geom and is_raster:
        raise ValueError(
            f"geom is not supported for {collection}; pass bbox instead."
        )

    bbox_values = _parse_bbox(bbox)
    point = None
    if geom and collection != "prec_merge_daily-1":
        lat, lon = geom[0]['coordinates']
        point = (lat, lon)
    if point is None and bbox_values is None:
        raise ValueError(f"bbox is required for {collection}.")

    stac = Client.open(stac_url)
    query = dict(
        collection=collection,
        start_date=start_date,
        end_date=end_date,
        bbox=bbox,
        bands=bands
    )
    bands_dict = collection_get_list(stac, query)

    try:
        sample_image_path = bands_dict[bands[0]][0]
    except (KeyError, IndexError, TypeError):
        print(f"{collection}'s {bands[0]} not found.")
        return None

    if collection == "prec_merge_daily-1":
        return _load_prec_merge(collection, bands_dict, bands, bbox_values)
    if not is_raster:
        return _load_latlon_netcdf(
            collection, bands_dict, bands, bbox_values, point
        )
    return _load_raster(
        collection, bands_dict, bands, bbox_values, sample_image_path
    )