diff --git a/colormaps/gsl_wind_speed.yaml b/colormaps/gsl_wind_speed.yaml new file mode 100644 index 0000000..b0f9c99 --- /dev/null +++ b/colormaps/gsl_wind_speed.yaml @@ -0,0 +1,23 @@ +# Colors used for wind speed as defined by NOAA-GSL's pygraf utility +plot_under: False +colors: + - '#fef8fe' + - '#f8d3f9' + - '#f1a5f3' + - '#e074f0' + - '#0045ff' + - '#0099ff' + - '#00ceff' + - '#00e8ff' + - '#00ffe6' + - '#67d300' + - '#7ffa06' + - '#b4ff36' + - '#eaff12' + - '#ffe500' + - '#ffc808' + - '#ff8608' + - '#ff3300' + - '#ff0039' + - '#f704fc' + diff --git a/custom_functions.py b/custom_functions.py index c8ef0b1..f168e95 100644 --- a/custom_functions.py +++ b/custom_functions.py @@ -4,6 +4,7 @@ """ import logging import uxarray as ux +import numpy as np logger = logging.getLogger(__name__) @@ -62,10 +63,38 @@ def vert_min(field: ux.UxDataArray, dim: str = "nVertLevels") -> ux.UxDataArray: return vertmin +def sum_of_magnitudes(field1: ux.UxDataArray, field2: ux.UxDataArray) -> ux.UxDataArray: + """ + Take two vectors (usually wind vectors) and return the sum of the magnitudes + """ + + return np.sqrt(np.square(field1) + np.square(field2)) + +def max_all_times(field: ux.UxDataArray, dim: str = "Time") -> ux.UxDataArray: + """ + Return the maximum value across all input times for a given point. + """ + # Compute differences along Time + result = field.max(dim=dim, keep_attrs=True) + + return result + +def min_all_times(field: ux.UxDataArray, dim: str = "Time") -> ux.UxDataArray: + """ + Return the minimum value across all input times for a given point. + """ + # Compute differences along Time + result = field.min(dim=dim, keep_attrs=True) + + return result + DERIVED_FUNCTIONS = { "diff_prev_timestep": diff_prev_timestep, "sum_fields": sum_fields, "vert_max": vert_max, "vert_min": vert_min, + "sum_of_magnitudes": sum_of_magnitudes, + "max_all_times": max_all_times, + "min_all_times": min_all_times, } diff --git a/default_options.yaml b/default_options.yaml index f6c02e4..9b7ee12 100644 --- a/default_options.yaml +++ b/default_options.yaml @@ -109,6 +109,8 @@ plot: # {fnme} = Name of file (minus extension) being read for plotted data # {date} = The date of plotted data, in %Y-%m-%d format # {time} = The time of plotted data, in %H:%M:%S format + # {maxval} = The maximum value in the plotted data + # {minval} = The minimum value in the plotted data filename: '{var}_{lev}.png' format: null @@ -116,9 +118,9 @@ plot: text: 'Plot of {varln}, level {lev} for MPAS forecast, {date} {time}' fontsize: 8 exists: rename - dpi: 300 - figheight: 4 - figwidth: 8 + dpi: 200 + figheight: 3 + figwidth: 6 # colormap: # Color scheme to use for output plots. Options can either be standard Matplotlib colormaps (reference @@ -130,10 +132,22 @@ plot: # colormap: "viridis" + # pixel_ratio: + # This controls the quantity of pixels to sample in the rasterization process; higher numbers result + # in higher quality plots, though at a cost of plotting speed. + pixel_ratio: 1 + + # polycollection: + # NOT RECOMMENDED + # This is the legacy plotting method that converts the unstructured grid to a set of polygons. + # This can be orders of magnitude slower than the default raster method and so is not + # recommended for large domains; if you need more detail in your plot, it's recommended to + # increase the "pixel_ratio" setting. # periodic_bdy: - # For periodic domains (including global), the plot routines will omit the boundary cells by default. To plot + # For periodic domains (including global), the polycollection routines will omit the boundary cells by default. To plot # all data, including boundaries, set this option to True, but note that it will slow down plotting substantially. # + polycollection: False periodic_bdy: False # vmin, vmax: diff --git a/environment.yml b/environment.yml index c4dd463..21ce0f9 100644 --- a/environment.yml +++ b/environment.yml @@ -8,7 +8,6 @@ dependencies: - numpy=1.26* - matplotlib - netcdf4 - - xarray=2025.9.0 - cartopy - uwtools=2.9* - - uxarray=2025.05* + - uxarray=2025.11* diff --git a/plot_functions.py b/plot_functions.py index 44df44c..654e551 100644 --- a/plot_functions.py +++ b/plot_functions.py @@ -7,6 +7,7 @@ import os import traceback +import numpy as np import cartopy.crs as ccrs logger = logging.getLogger(__name__) @@ -47,6 +48,10 @@ def set_patterns_and_outfile(valid, var, lev, filepath, field, ftime, plotdict): #filename minus extension fnme=os.path.splitext(filename)[0] + # max and min values for plotted field + maxval=float(field.max().compute()) + minval=float(field.min().compute()) + pattern_dict = { "var": var, "lev": lev, @@ -56,7 +61,9 @@ def set_patterns_and_outfile(valid, var, lev, filepath, field, ftime, plotdict): "fnme": fnme, "proj": plotdict["projection"]["projection"], "date": "no_Time_dimension", - "time": "no_Time_dimension" + "time": "no_Time_dimension", + "maxval": f"{maxval:.2f}", + "minval": f"{minval:.2f}", } if field.attrs.get("units"): pattern_dict.update({ @@ -220,3 +227,67 @@ def set_map_projection(confproj) -> ccrs.Projection: raise ValueError(f"Invalid projection {proj} specified; valid options are:\n{valid}") + +def get_data_extent_raster(raster, lon_bounds=(-180, 180), lat_bounds=(-90, 90)): + """ + Computes data extent from image raster for automatic zooming to data domain + + Parameters + ---------- + raster : np.ndarray + 2D raster array with NaNs outside valid region + lon_bounds : tuple(float, float) + Longitude range corresponding to full raster width + lat_bounds : tuple(float, float) + Latitude range corresponding to full raster height + + Returns + ------- + extent : list [lon_min, lon_max, lat_min, lat_max] + """ + valid = ~np.isnan(raster) + if not np.any(valid): + # no data at all + return lon_bounds + lat_bounds + + # pixel indices of valid data + ys, xs = np.where(valid) + + # convert indices to lon/lat using proportional scaling + nrows, ncols = raster.shape + lon_min, lon_max = lon_bounds + lat_min, lat_max = lat_bounds + + x_min = lon_min + (xs.min() / ncols) * (lon_max - lon_min) + x_max = lon_min + (xs.max() / ncols) * (lon_max - lon_min) + y_min = lat_max - (ys.max() / nrows) * (lat_max - lat_min) + y_max = lat_max - (ys.min() / nrows) * (lat_max - lat_min) + + pad_fraction=0.05 + dx = (x_max - x_min) * pad_fraction + dy = (y_max - y_min) * pad_fraction + # y dimension is flipped for some reason + return [x_min - dx, x_max + dx, -y_max - dy, -y_min + dy] + + +def get_data_extent(uxda, pad_fraction=0.05): + """Return (lon_min, lon_max, lat_min, lat_max) in degrees, with buffer.""" + try: + if "n_face" in uxda.dims: + lons = getattr(uxda.uxgrid, "node_lon", None) + lats = getattr(uxda.uxgrid, "node_lat", None) + else: + lons = uxda.lon + lats = uxda.lat + + lon_min = np.nanmin(lons) + lon_max = np.nanmax(lons) + lat_min = np.nanmin(lats) + lat_max = np.nanmax(lats) + + dx = (lon_max - lon_min) * pad_fraction + dy = (lat_max - lat_min) * pad_fraction + + return [lon_min - dx, lon_max + dx, lat_min - dy, lat_max + dy] + except Exception as e: + raise RuntimeError(f"Could not determine lat/lon bounds: {e}") diff --git a/plot_mpas_netcdf.py b/plot_mpas_netcdf.py index f2458f3..08e32c6 100644 --- a/plot_mpas_netcdf.py +++ b/plot_mpas_netcdf.py @@ -17,6 +17,7 @@ proc = psutil.Process(os.getpid()) import matplotlib as mpl +mpl.use("Agg") import matplotlib.pyplot as plt import cartopy.feature as cfeature import cartopy.crs as ccrs @@ -28,7 +29,7 @@ import uwtools.api.config as uwconfig import custom_functions -from plot_functions import set_map_projection, set_patterns_and_outfile +from plot_functions import set_map_projection, set_patterns_and_outfile, get_data_extent logger = logging.getLogger(__name__) @@ -195,12 +196,27 @@ def plotithandler(config_d: dict,uxds: ux.UxDataset,var: str,lev: int,timeint: i if timestring: ftime_dt = datetime.strptime(timestring.strip(), "%Y-%m-%d_%H:%M:%S") + # timeint was set to -1 in setup_args if variable has no time dimension if timeint == -1: - plotit(config_d['dataset']['vars'][var],field,var,lev,config_d['dataset']['files'][0],ftime_dt) + timeint=0 + plotfield=field else: - plotit(config_d['dataset']['vars'][var],field.isel(Time=timeint),var,lev,config_d['dataset']['files'][timeint],ftime_dt) - + plotfield=field.isel(Time=timeint) + try: + plotit(config_d['dataset']['vars'][var],plotfield,var,lev,config_d['dataset']['files'][0],ftime_dt) + except KeyboardInterrupt: + # Simply return on keyboard interrupts + logger.warning("KeyboardInterrupt detected; stopping process...") + return + except Exception as e: + logger.error(f'Could not plot variable {var}, level {lev}') + logger.debug(f"Arguments to plotit():\n{config_d['dataset']['vars'][var]=}\n{plotfield=}\n"\ + f"{var=}\n{lev=}\n{config_d['dataset']['files'][0]=}\n{ftime_dt=}") + logger.error(f"{traceback.print_tb(e.__traceback__)}:") + logger.error(f"{type(e).__name__}:") + logger.error(e) + raise PlotError def plotit(vardict: dict,uxda: ux.UxDataArray,var: str,lev: int,filepath: str,ftime) -> None: """ @@ -236,73 +252,17 @@ def plotit(vardict: dict,uxda: ux.UxDataArray,var: str,lev: int,filepath: str,ft logger.debug(f"{varslice=}") - logger.debug(f"Memory usage:{proc.memory_info().rss/1024**2} MB") - try: - if plotdict["periodic_bdy"]: - logger.info("Creating polycollection with periodic_bdy=True") - logger.info("NOTE: This option can be very slow for large domains") - pc=varslice.to_polycollection(periodic_elements='split') - else: - pc=varslice.to_polycollection() - logger.debug(f"Memory usage:{proc.memory_info().rss/1024**2} MB") - except ValueError as e: - logger.critical(e) - msg=f"Variable {var} may not have standard vertical levels (nVertLevels):\n" - msg+=f"dimensions={varslice.dims}\n" - msg+="Check documentation for how to handle variables with non-standard vertical levels" - raise PlotError(msg) - - pc.set_antialiased(False) - - # Handle color mapping - cmapname=plotdict["colormap"] - if cmapname in plt.colormaps(): - cmap=mpl.colormaps[cmapname] - pc.set_cmap(plotdict["colormap"]) - elif os.path.exists(colorfile:=f"colormaps/{cmapname}.yaml"): - cmap_settings = uwconfig.get_yaml_config(config=colorfile) - #Overwrite additional settings specified in colormap file - logger.info(f"Color map {cmapname} selected; using custom settings from {colorfile}") - for setting in cmap_settings: - if setting == "colors": - # plot:colors is a list of color values for the custom colormap and is handled separately - continue - logger.debug(f"Overwriting config {setting} with custom value {cmap_settings[setting]} from {colorfile}") - if isinstance(plotdict.get(setting),dict): - plotdict[setting]=deep_merge(plotdict[setting],cmap_settings[setting]) - else: - plotdict[setting]=cmap_settings[setting] - if not (colorbins:=plotdict.get("colorbins")): - colorbins=256 - cmap = mpl.colors.LinearSegmentedColormap.from_list(name="custom",colors=cmap_settings["colors"],N=colorbins) - else: - raise ValueError(f"Requested color map {cmapname} is not valid") - - if not plotdict["plot_over"]: - cmap.set_over(alpha=0) - if not plotdict["plot_under"]: - cmap.set_under(alpha=0) - pc.set_cmap(cmap) - # Set up map projection properties logger.debug(plotdict["projection"]) proj=set_map_projection(plotdict["projection"]) + # Create figure and plot axes fig, ax = plt.subplots(1, 1, figsize=(plotdict["figwidth"], plotdict["figheight"]), dpi=plotdict["dpi"], constrained_layout=True, subplot_kw=dict(projection=proj)) - # Check the valid file formats supported for this figure - validfmts=fig.canvas.get_supported_filetypes() - - logger.debug(f"{plotdict['projection']['lonrange']=}\n{plotdict['projection']['latrange']=}") - if None in plotdict["projection"]["lonrange"] or None in plotdict["projection"]["latrange"]: - logger.info('One or more latitude/longitude range values were not set; plotting full projection') - else: - ax.set_extent([plotdict["projection"]["lonrange"][0], plotdict["projection"]["lonrange"][1], plotdict["projection"]["latrange"][0], plotdict["projection"]["latrange"][1]], crs=ccrs.PlateCarree()) - - pc.set_clim(vmin=plotdict["vmin"],vmax=plotdict["vmax"]) + # Create figure and plot axes #Plot political boundaries if requested if plotdict.get("boundaries"): @@ -311,7 +271,6 @@ def plotit(vardict: dict,uxda: ux.UxDataArray,var: str,lev: int,filepath: str,ft # Users can set these values to scalars or lists; if scalar provided, re-format to list with three identical values for setting in ["color", "linewidth", "scale"]: if type(pb[setting]) is not list: - pb[setting]=[pb[setting],pb[setting],pb[setting]] if pb["detail"]==2: ax.add_feature(cfeature.NaturalEarthFeature(category='cultural', @@ -337,21 +296,116 @@ def plotit(vardict: dict,uxda: ux.UxDataArray,var: str,lev: int,filepath: str,ft ax.add_feature(cfeature.NaturalEarthFeature(category='physical',edgecolor=pl["color"],facecolor='none', linewidth=pl["linewidth"], scale=pl["scale"], name='lakes')) - # Create a dict of substitutable patterns to make string substitutions easier, and determine output filename - patterns,outfile,fmt = set_patterns_and_outfile(validfmts,var,lev,filepath,uxda,ftime,plotdict) - pc.set_edgecolor(plotdict['edges']['color']) - pc.set_linewidth(plotdict['edges']['width']) - pc.set_transform(ccrs.PlateCarree()) + # Handle color mapping + cmapname=plotdict["colormap"] + if cmapname in plt.colormaps(): + if bins:=plotdict.get("colorbins"): + cmap=plt.get_cmap(cmapname, bins) + else: + cmap=mpl.colormaps[cmapname] + elif os.path.exists(colorfile:=f"colormaps/{cmapname}.yaml"): + cmap_settings = uwconfig.get_yaml_config(config=colorfile) + #Overwrite additional settings specified in colormap file + logger.info(f"Color map {cmapname} selected; using custom settings from {colorfile}") + for setting in cmap_settings: + if setting == "colors": + # plot:colors is a list of color values for the custom colormap and is handled separately + continue + logger.debug(f"Overwriting config {setting} with custom value {cmap_settings[setting]} from {colorfile}") + if isinstance(plotdict.get(setting),dict): + plotdict[setting]=deep_merge(plotdict[setting],cmap_settings[setting]) + else: + plotdict[setting]=cmap_settings[setting] + if not (colorbins:=plotdict.get("colorbins")): + colorbins=256 + cmap = mpl.colors.LinearSegmentedColormap.from_list(name="custom",colors=cmap_settings["colors"],N=colorbins) + else: + raise ValueError(f"Requested color map {cmapname} is not valid") - logger.debug("Adding collection to plot axes") - if plotdict["projection"]["projection"] != "PlateCarree": - logger.info(f"Interpolating to {plotdict['projection']['projection']} projection; this may take a while...") + if not plotdict["plot_over"]: + cmap.set_over(alpha=0) + if not plotdict["plot_under"]: + cmap.set_under(alpha=0) + + + logger.debug(f"Memory usage:{proc.memory_info().rss/1024**2} MB") + + # Set axes extent based on data extent and/or user settings if None in plotdict["projection"]["lonrange"] or None in plotdict["projection"]["latrange"]: - coll = ax.add_collection(pc, autolim=True) - ax.autoscale() + logger.info('One or more latitude/longitude range values were not set; plotting full projection') + extent=get_data_extent(varslice) + # If auto-calculated extent is greater than globe, reset to full globe + if (extent[1] - extent[0]) > 360: + extent[1]=180 + extent[0]=-180 + if (extent[3] - extent[2]) > 180: + extent[3]=90 + extent[2]=-90 + else: + extent=[plotdict["projection"]["lonrange"][0], plotdict["projection"]["lonrange"][1], plotdict["projection"]["latrange"][0], plotdict["projection"]["latrange"][1]] + logger.debug(f'Domain extent: {extent}') + + ax.set_extent(extent, crs=ccrs.PlateCarree()) + + # Create image with polycollection or raster + if plotdict["polycollection"]: + try: + if plotdict["periodic_bdy"]: + logger.info("Creating polycollection with periodic_bdy=True") + logger.info("NOTE: This option can be very slow for large domains") + pc=varslice.to_polycollection(periodic_elements='split') + else: + pc=varslice.to_polycollection() + logger.debug(f"Memory usage:{proc.memory_info().rss/1024**2} MB") + except ValueError as e: + logger.critical(e) + msg=f"Variable {var} may not have standard vertical levels (nVertLevels):\n" + msg+=f"dimensions={varslice.dims}\n" + msg+="Check documentation for how to handle variables with non-standard vertical levels" + raise PlotError(msg) + + pc.set_antialiased(False) + + pc.set_edgecolor(plotdict['edges']['color']) + pc.set_linewidth(plotdict['edges']['width']) + pc.set_transform(ccrs.Geodetic()) + + logger.debug("Adding collection to plot axes") + if plotdict["projection"]["projection"] != "PlateCarree": + logger.info(f"Interpolating to {plotdict['projection']['projection']} projection; this may take a while...") + if None in plotdict["projection"]["lonrange"] or None in plotdict["projection"]["latrange"]: + img = ax.add_collection(pc, autolim=True) + ax.autoscale() + else: + img = ax.add_collection(pc) else: - coll = ax.add_collection(pc) + try: + raster = varslice.to_raster(ax=ax, pixel_ratio=plotdict['pixel_ratio']) + except ValueError as e: + logger.critical(e) + msg=f"Variable {var} may not have standard vertical levels (nVertLevels):\n" + msg+=f"dimensions={varslice.dims}\n" + msg+="Check documentation for how to handle variables with non-standard vertical levels" + raise PlotError(msg) + + img = ax.imshow( + raster, + cmap=cmap, + origin="lower", + extent=ax.get_xlim() + ax.get_ylim(), + ) + + img.set_clim(vmin=plotdict["vmin"],vmax=plotdict["vmax"]) + img.set_cmap + + # Check the valid file formats supported for this figure + validfmts=fig.canvas.get_supported_filetypes() + + # Check the valid file formats supported for this figure + validfmts=fig.canvas.get_supported_filetypes() + # Create a dict of substitutable patterns to make string substitutions easier, and determine output filename + patterns,outfile,fmt = set_patterns_and_outfile(validfmts,var,lev,filepath,uxda,ftime,plotdict) logger.debug("Configuring plot title") if plottitle:=plotdict["title"].get("text"): @@ -363,7 +417,7 @@ def plotit(vardict: dict,uxda: ux.UxDataArray,var: str,lev: int,filepath: str,ft if plotdict.get("colorbar"): if plotdict.get("colorbar").get("enable"): cb = plotdict["colorbar"] - cbar = plt.colorbar(coll,ax=ax,orientation=cb["orientation"]) + cbar = plt.colorbar(img,ax=ax,orientation=cb["orientation"]) if cb.get("label"): cbar.set_label(cb["label"].format_map(patterns), fontsize=cb["fontsize"]) cbar.ax.tick_params(labelsize=cb["fontsize"]) @@ -542,6 +596,16 @@ def setup_config(config: str, default: str="default_options.yaml") -> dict: raise TypeError("plot:title should be a dictionary, not a string\n"\ "Adjust your config.yaml accordingly. See default_options.yaml for details.") + if not (None in expt_config["plot"]["projection"]["lonrange"] or None in expt_config["plot"]["projection"]["latrange"]): + lon0=expt_config["plot"]["projection"]["lonrange"][0] + lon1=expt_config["plot"]["projection"]["lonrange"][1] + lat0=expt_config["plot"]["projection"]["latrange"][0] + lat1=expt_config["plot"]["projection"]["latrange"][1] + if lon0 >= lon1: + raise ValueError(f"plot:projection:lonrange first value {lon0} >= second value {lon1}") + if lat0 >= lat1: + raise ValueError(f"plot:projection:latrange first value {lat0} >= second value {lat1}") + logger.debug("Expanding references to other variables and Jinja templates") expt_config.dereference() return expt_config @@ -553,6 +617,7 @@ def worker_init(debug=False): if __name__ == "__main__": + scriptstart = time.time() parser = argparse.ArgumentParser( description="Script for plotting a custom field on the native MPAS grid from native NetCDF format files" ) @@ -584,15 +649,19 @@ def worker_init(debug=False): logger.debug(f"Memory usage:{proc.memory_info().rss/1024**2} MB") plotargs=setup_args(expt_config,dataset) - logger.info('Submitting to starmap') logger.debug(f"Memory usage:{proc.memory_info().rss/1024**2} MB") logger.debug(f"{plotargs=}") # Make the plots! if args.procs > 1: logger.info(f"Plotting in parallel with {args.procs} tasks") - # This is needed to avoid some kind of file handle clobbering mumbo-jumbo with netCDF - multiprocessing.set_start_method("spawn") - with multiprocessing.Pool(processes=args.procs,initializer=worker_init,initargs=(args.debug,)) as pool: - pool.starmap(plotithandler, plotargs) + # This is needed to avoid some kind of file handle clobbering mumbo-jumbo with netCDF + multiprocessing.set_start_method("spawn") + with multiprocessing.Pool(processes=args.procs,initializer=worker_init,initargs=(args.debug,)) as pool: + pool.starmap(plotithandler, plotargs) + else: + i=0 + for instance in plotargs: + plotithandler(*plotargs[i]) + i+=1 - logger.info("Done plotting all figures!") + logger.info(f"Done plotting all figures! Total time {time.time()-scriptstart} seconds") diff --git a/setup_conda.sh b/setup_conda.sh index 10cb250..0ef7894 100644 --- a/setup_conda.sh +++ b/setup_conda.sh @@ -30,6 +30,4 @@ fi if [[ ! "$LD_LIBRARY_PATH" =~ "$CONDA_BUILD_DIR" ]]; then export LD_LIBRARY_PATH=${CONDA_BUILD_DIR}/lib:${LD_LIBRARY_PATH} fi - -echo "To activate the mpas_plot environment, run this command:" -echo " conda activate mpas_plot" +conda activate mpas_plot