Source code for improver.calibration.reliability_calibration

# (C) Crown Copyright, Met Office. All rights reserved.
#
# This file is part of 'IMPROVER' and is released under the BSD 3-Clause license.
# See LICENSE in the root of the repository for full licensing details.
"""Reliability calibration plugins."""

import operator
import warnings
from typing import Dict, List, Optional, Tuple, Union

import iris
import numpy as np
import scipy
from iris.coords import AuxCoord, DimCoord
from iris.cube import Cube, CubeList
from iris.exceptions import CoordinateNotFoundError
from numpy import ndarray
from numpy.ma.core import MaskedArray

from improver import BasePlugin, PostProcessingPlugin
from improver.calibration.utilities import (
    check_forecast_consistency,
    create_unified_frt_coord,
    filter_non_matching_cubes,
)
from improver.metadata.probabilistic import (
    find_threshold_coordinate,
    probability_is_above_or_below,
)
from improver.metadata.utilities import generate_mandatory_attributes
from improver.utilities.cube_manipulation import (
    MergeCubes,
    collapsed,
    enforce_coordinate_ordering,
    get_dim_coord_names,
)


class ConstructReliabilityCalibrationTables(BasePlugin):
    """A plugin for creating and populating reliability calibration tables."""

    def __init__(
        self,
        n_probability_bins: int = 5,
        single_value_lower_limit: bool = False,
        single_value_upper_limit: bool = False,
    ) -> None:
        """
        Initialise class for creating reliability calibration tables. These
        tables include data columns entitled observation_count,
        sum_of_forecast_probabilities, and forecast_count, defined below.

        Args:
            n_probability_bins:
                The total number of probability bins required in the
                reliability tables. If single value limits are turned on,
                these are included in this total.
            single_value_lower_limit:
                Mandates that the lowest bin should be single valued, with a
                small precision tolerance, defined as 1.0E-6. The bin is thus
                0 to 1.0E-6.
            single_value_upper_limit:
                Mandates that the highest bin should be single valued, with a
                small precision tolerance, defined as 1.0E-6. The bin is thus
                (1 - 1.0E-6) to 1.
        """
        self.single_value_tolerance = 1.0e-6
        self.probability_bins = self._define_probability_bins(
            n_probability_bins, single_value_lower_limit, single_value_upper_limit
        )
        self.table_columns = np.array(
            ["observation_count", "sum_of_forecast_probabilities", "forecast_count"]
        )
        self.expected_table_shape = (len(self.table_columns), n_probability_bins)

    def __repr__(self) -> str:
        """Represent the configured plugin instance as a string."""
        bin_values = ", ".join(
            ["[{:1.2f} --> {:1.2f}]".format(*item) for item in self.probability_bins]
        )
        result = "<ConstructReliabilityCalibrationTables: probability_bins: {}>"
        return result.format(bin_values)

    def _define_probability_bins(
        self,
        n_probability_bins: int,
        single_value_lower_limit: bool,
        single_value_upper_limit: bool,
    ) -> ndarray:
        """
        Define equally sized probability bins for use in a reliability table.
        The range 0 to 1 is divided into ranges to give n_probability_bins
        bins. If single_value_lower_limit and / or single_value_upper_limit
        are True, additional bins corresponding to values of 0 and / or 1 will
        be created, each with a width defined by self.single_value_tolerance.

        Args:
            n_probability_bins:
                The total number of probability bins desired in the
                reliability tables. This number includes the extrema bins
                (equals 0 and equals 1) if single value limits are turned on,
                in which case the minimum number of bins is 3.
            single_value_lower_limit:
                Mandates that the lowest bin should be single valued, with a
                small precision tolerance, defined as 1.0E-6. The bin is thus
                0 to 1.0E-6.
            single_value_upper_limit:
                Mandates that the highest bin should be single valued, with a
                small precision tolerance, defined as 1.0E-6. The bin is thus
                (1 - 1.0E-6) to 1.

        Returns:
            An array of 2-element arrays that contain the bounds of the
            probability bins. These bounds are non-overlapping, with adjacent
            bin boundaries spaced at the smallest representable interval.

        Raises:
            ValueError: If trying to use both single_value_lower_limit and
                single_value_upper_limit with 2 or fewer probability bins.
        """
        if single_value_lower_limit and single_value_upper_limit:
            if n_probability_bins <= 2:
                msg = (
                    "Cannot use both single_value_lower_limit and "
                    "single_value_upper_limit with 2 or fewer "
                    "probability bins."
                )
                raise ValueError(msg)
            n_probability_bins = n_probability_bins - 2
        elif single_value_lower_limit or single_value_upper_limit:
            n_probability_bins = n_probability_bins - 1

        bin_lower = np.linspace(0, 1, n_probability_bins + 1, dtype=np.float32)
        bin_upper = np.nextafter(bin_lower, 0, dtype=np.float32)
        bin_upper[-1] = 1.0
        bins = np.stack([bin_lower[:-1], bin_upper[1:]], 1).astype(np.float32)

        if single_value_lower_limit:
            bins[0, 0] = np.nextafter(self.single_value_tolerance, 1, dtype=np.float32)
            lowest_bin = np.array([0, self.single_value_tolerance], dtype=np.float32)
            bins = np.vstack([lowest_bin, bins]).astype(np.float32)

        if single_value_upper_limit:
            bins[-1, 1] = np.nextafter(
                1.0 - self.single_value_tolerance, 0, dtype=np.float32
            )
            highest_bin = np.array(
                [1.0 - self.single_value_tolerance, 1], dtype=np.float32
            )
            bins = np.vstack([bins, highest_bin]).astype(np.float32)

        return bins

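    # Worked example (illustrative, approximate): with n_probability_bins=5 and
    # both single-value limits enabled, the resulting bin bounds are
    #   [0, 1e-6], [1e-6, 1/3], [1/3, 2/3], [2/3, 1 - 1e-6], [1 - 1e-6, 1]
    # with adjacent bounds separated by the smallest representable float32 step.
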
    def _create_probability_bins_coord(self) -> DimCoord:
        """
        Construct a dimension coordinate describing the probability bins of
        the reliability table.

        Returns:
            A dimension coordinate describing probability bins.
        """
        values = np.mean(self.probability_bins, axis=1, dtype=np.float32)
        probability_bins_coord = iris.coords.DimCoord(
            values, long_name="probability_bin", units=1, bounds=self.probability_bins
        )
        return probability_bins_coord

    def _create_reliability_table_coords(self) -> Tuple[DimCoord, AuxCoord]:
        """
        Construct coordinates that describe the reliability table rows. These
        are observation_count, sum_of_forecast_probabilities, and
        forecast_count. The order used here is the order in which the table
        data is populated, so these must remain consistent with the
        _populate_reliability_bins function.

        Returns:
            - A numerical index dimension coordinate.
            - An auxiliary coordinate that assigns names to the index
              coordinates, where these names correspond to the reliability
              table rows.
        """
        index_coord = iris.coords.DimCoord(
            np.arange(len(self.table_columns), dtype=np.int32),
            long_name="table_row_index",
            units=1,
        )
        name_coord = iris.coords.AuxCoord(
            self.table_columns, long_name="table_row_name", units=1
        )
        return index_coord, name_coord

    @staticmethod
    def _define_metadata(forecast_slice: Cube) -> Dict[str, str]:
        """
        Define metadata that is specifically required for reliability table
        cubes, whilst ensuring any mandatory attributes are also populated.

        Args:
            forecast_slice:
                The source cube from which to get pre-existing metadata of use.

        Returns:
            A dictionary of attributes that are appropriate for the
            reliability table cube.
        """
        attributes = generate_mandatory_attributes([forecast_slice])
        attributes["title"] = "Reliability calibration data table"
        return attributes

    def _create_reliability_table_cube(
        self, forecast: Cube, threshold_coord: DimCoord
    ) -> Cube:
        """
        Construct a reliability table cube and populate it with the provided
        data. The returned cube will include a forecast_reference_time
        coordinate, which will be the maximum range of bounds of the input
        forecast reference times, with the point value set to the latest of
        those in the inputs. It will further include the forecast period,
        threshold coordinate, and spatial coordinates from the forecast cube.

        Args:
            forecast:
                A cube slice across the spatial dimensions of the forecast
                data. This slice provides the time and threshold values that
                relate to the reliability_table_data.
            threshold_coord:
                The threshold coordinate.

        Returns:
            A reliability table cube.
        """

        def _get_coords_and_dims(coord_names: List[str]) -> List[Tuple[DimCoord, int]]:
            """Obtain the requested coordinates and their dimension index from
            the forecast slice cube."""
            coords_and_dims = []
            leading_coords = [probability_bins_coord, reliability_index_coord]
            for coord_name in coord_names:
                crd = forecast_slice.coord(coord_name)
                crd_dim = forecast_slice.coord_dims(crd)
                crd_dim = crd_dim[0] + len(leading_coords) if crd_dim else ()
                coords_and_dims.append((crd, crd_dim))
            return coords_and_dims

        forecast_slice = next(forecast.slices_over(["time", threshold_coord]))
        expected_shape = self.expected_table_shape + forecast_slice.shape
        dummy_data = np.zeros((expected_shape))

        diagnostic = find_threshold_coordinate(forecast).name()
        attributes = self._define_metadata(forecast)

        # Define reliability table specific coordinates.
        probability_bins_coord = self._create_probability_bins_coord()
        (
            reliability_index_coord,
            reliability_name_coord,
        ) = self._create_reliability_table_coords()
        frt_coord = create_unified_frt_coord(forecast.coord("forecast_reference_time"))

        # List of required non-spatial coordinates from the forecast.
        non_spatial_coords = ["forecast_period", diagnostic]

        # Construct a list of coordinates in the desired order.
        aux_coords_and_dims = _get_coords_and_dims(non_spatial_coords)
        aux_coords_and_dims.append((reliability_name_coord, 0))
        spatial_coords = [forecast.coord(axis=dim).name() for dim in ["x", "y"]]
        spatial_coords_and_dims = _get_coords_and_dims(spatial_coords)
        try:
            spot_index_coord = _get_coords_and_dims(["spot_index"])
            wmo_id_coord = _get_coords_and_dims(["wmo_id"])
        except CoordinateNotFoundError:
            dim_coords_and_dims = spatial_coords_and_dims
        else:
            dim_coords_and_dims = spot_index_coord
            aux_coords_and_dims.extend(spatial_coords_and_dims + wmo_id_coord)

        dim_coords_and_dims.append((reliability_index_coord, 0))
        dim_coords_and_dims.append((probability_bins_coord, 1))

        reliability_cube = iris.cube.Cube(
            dummy_data,
            units=1,
            attributes=attributes,
            dim_coords_and_dims=dim_coords_and_dims,
            aux_coords_and_dims=aux_coords_and_dims,
        )
        reliability_cube.add_aux_coord(frt_coord)
        reliability_cube.rename("reliability_calibration_table")

        return reliability_cube

    def _populate_reliability_bins(
        self, forecast: Union[MaskedArray, ndarray], truth: Union[MaskedArray, ndarray]
    ) -> MaskedArray:
        """
        For a spatial slice at a single validity time and threshold, populate
        a reliability table using the provided truth.

        Args:
            forecast:
                An array containing data over a spatial slice for a single
                validity time and threshold.
            truth:
                An array containing a thresholded gridded truth at an
                equivalent validity time to the forecast array.

        Returns:
            An array containing reliability table data for a single time and
            threshold. The leading dimension corresponds to the rows of a
            calibration table, the second dimension to the number of
            probability bins, and the trailing dimension(s) are the spatial
            dimension(s) of the forecast and truth cubes (which are
            equivalent).
        """
        bin_edges = np.concatenate(
            [
                np.array(self.probability_bins[:, 0]),
                np.array([self.probability_bins[-1, 1] + self.single_value_tolerance]),
            ]
        ).astype(self.probability_bins.dtype)
        bin_index = np.searchsorted(bin_edges, forecast, side="right") - 1
        # NaN values are given index len(bin_edges) - 1, one beyond the highest
        # valid bin index. To make put_along_axis work, the leading dimension
        # of the new shape is therefore also one larger than the number of
        # bins, and the extra slice is discarded later.
        new_shape = (len(bin_edges),) + forecast.shape
        forecast_mask = np.broadcast_to(
            np.expand_dims(np.ma.getmask(forecast), 0), new_shape
        )

        forecast_probabilities = np.zeros(new_shape, dtype=forecast.dtype)
        np.put_along_axis(
            forecast_probabilities, np.expand_dims(bin_index, 0), forecast, axis=0
        )
        forecast_probabilities = np.ma.array(
            forecast_probabilities, mask=forecast_mask, copy=False
        )

        forecast_counts = np.zeros_like(forecast_probabilities)
        np.put_along_axis(forecast_counts, np.expand_dims(bin_index, 0), 1, axis=0)
        forecast_counts = np.ma.array(forecast_counts, mask=forecast_mask, copy=False)

        observation_counts = (
            np.expand_dims(np.isclose(truth, 1), 0) & forecast_counts.astype(bool)
        ).astype(int)

        # Discard the last index of the leading dimension because it contains
        # data arising from NaN forecast values.
        reliability_table = np.ma.stack(
            [
                observation_counts[:-1, :],
                forecast_probabilities[:-1, :],
                forecast_counts[:-1, :],
            ]
        )

        return reliability_table.astype(np.float32)

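    # Worked example (illustrative): with two equal-width probability bins,
    # forecast probabilities [0.2, 0.7] at two points and truths [0, 1], the
    # returned table (row, probability_bin, point) is
    #   observation_count:             [[0, 0], [0, 1]]
    #   sum_of_forecast_probabilities: [[0.2, 0.], [0., 0.7]]
    #   forecast_count:                [[1, 0], [0, 1]]
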
    def _populate_masked_reliability_bins(
        self, forecast: ndarray, truth: MaskedArray
    ) -> MaskedArray:
        """
        Support populating the reliability table bins with a masked truth. If
        a masked truth is provided, a masked reliability table is returned.

        Args:
            forecast:
                An array containing data over an xy slice for a single
                validity time and threshold.
            truth:
                An array containing a thresholded gridded truth at an
                equivalent validity time to the forecast array.

        Returns:
            An array containing reliability table data for a single time and
            threshold. The leading dimension corresponds to the rows of a
            calibration table, the second dimension to the number of
            probability bins, and the trailing dimensions are the spatial
            dimensions of the forecast and truth cubes (which are equivalent).
        """
        forecast = np.ma.masked_where(np.ma.getmask(truth), forecast)
        table = self._populate_reliability_bins(forecast, truth)
        # Zero data underneath mask to support bitwise addition of masks.
        table.data[table.mask] = 0
        return table

    def _add_reliability_tables(
        self, forecast: Cube, truth: Cube, threshold_reliability: MaskedArray
    ) -> Union[MaskedArray, ndarray]:
        """
        Add reliability tables. The presence of a masked truth is handled
        separately to ensure support for a mask that changes with validity
        time.

        Args:
            forecast:
                A forecast cube containing data over an xy slice for a single
                validity time and threshold.
            truth:
                A cube containing a thresholded gridded truth at an equivalent
                validity time to the forecast.
            threshold_reliability:
                The current reliability table that will be added to.

        Returns:
            An array containing reliability table data for a single time and
            threshold. The leading dimension corresponds to the rows of a
            calibration table, the second dimension to the number of
            probability bins, and the trailing dimensions are the spatial
            dimensions of the forecast and truth cubes (which are equivalent).
        """
        if np.ma.is_masked(truth.data):
            table = self._populate_masked_reliability_bins(forecast.data, truth.data)
            # Bitwise addition of masks. This ensures that only points that are
            # masked in both the existing and new reliability tables are kept
            # as being masked within the resulting reliability table.
            mask = threshold_reliability.mask & table.mask
            threshold_reliability = np.ma.array(
                threshold_reliability.data + table.data, mask=mask, dtype=np.float32
            )
        else:
            np.add(
                threshold_reliability,
                self._populate_reliability_bins(forecast.data, truth.data),
                out=threshold_reliability,
                dtype=np.float32,
            )
        return threshold_reliability

    def process(
        self,
        historic_forecasts: Cube,
        truths: Cube,
        aggregate_coords: Optional[List[str]] = None,
    ) -> Cube:
        """
        Slice data over threshold and time coordinates to construct
        reliability tables. These are summed over time to give a single table
        for each threshold, constructed from all the provided historic
        forecasts and truths.

        If a masked truth is provided, a masked reliability table is returned.
        If the mask within the truth varies at different timesteps, any point
        that is unmasked for at least one timestep will have unmasked values
        within the reliability table. Therefore historic forecast points will
        only be used if they have a corresponding valid truth point for each
        timestep.

        .. See the documentation for an example of the resulting reliability
           table cube.
        .. include:: extended_documentation/calibration/
           reliability_calibration/reliability_calibration_examples.rst

        Note that the forecast and truth data used is probabilistic, i.e. has
        already been thresholded relative to the thresholds of interest, using
        the required comparison operator. As such this plugin is agnostic as
        to whether the data is thresholded below or above a given diagnostic
        threshold.

        `historic_forecasts` and `truths` should have matching validity times.

        Args:
            historic_forecasts:
                A cube containing the historical forecasts used in
                calibration.
            truths:
                A cube containing the thresholded gridded truths used in
                calibration.
            aggregate_coords:
                Coordinates to aggregate over during construction. This is
                equivalent to constructing then using
                :class:`improver.calibration.reliability_calibration.AggregateReliabilityCalibrationTables`
                but with reduced memory usage due to avoiding large
                intermediate data.

        Returns:
            A reliability table cube containing a table for each threshold in
            the historic forecast cube, merged along the threshold coordinate.

        Raises:
            ValueError: If the forecast and truth cubes have differing
                threshold coordinates.
        """
        historic_forecasts, truths = filter_non_matching_cubes(
            historic_forecasts, truths
        )

        threshold_coord = find_threshold_coordinate(historic_forecasts)
        truth_threshold_coord = find_threshold_coordinate(truths)
        if not threshold_coord == truth_threshold_coord:
            msg = "Threshold coordinates differ between forecasts and truths."
            raise ValueError(msg)

        time_coord = historic_forecasts.coord("time")

        check_forecast_consistency(historic_forecasts)
        reliability_cube = self._create_reliability_table_cube(
            historic_forecasts, threshold_coord
        )

        populate_bins_func = self._populate_reliability_bins
        if np.ma.is_masked(truths.data):
            populate_bins_func = self._populate_masked_reliability_bins

        reliability_tables = iris.cube.CubeList()
        threshold_slices = zip(
            historic_forecasts.slices_over(threshold_coord),
            truths.slices_over(threshold_coord),
        )
        for forecast_slice, truth_slice in threshold_slices:
            time_slices = zip(
                forecast_slice.slices_over(time_coord),
                truth_slice.slices_over(time_coord),
            )
            forecast, truth = next(time_slices)
            threshold_reliability = populate_bins_func(forecast.data, truth.data)

            for forecast, truth in time_slices:
                threshold_reliability = self._add_reliability_tables(
                    forecast, truth, threshold_reliability
                )

            reliability_entry = reliability_cube.copy(data=threshold_reliability)
            reliability_entry.replace_coord(forecast_slice.coord(threshold_coord))
            if aggregate_coords:
                reliability_entry = AggregateReliabilityCalibrationTables().process(
                    [reliability_entry], aggregate_coords
                )
            reliability_tables.append(reliability_entry)

        return MergeCubes()(reliability_tables, copy=False)

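# Illustrative usage sketch (not part of the module): the helper below is
# hypothetical and shows how the plugin above might be driven with matching
# historic forecast and truth probability cubes, aggregating over the spatial
# coordinates during construction to limit memory use.
def _example_construct_tables(historic_forecasts: Cube, truths: Cube) -> Cube:
    """Build reliability tables aggregated over the spatial domain."""
    plugin = ConstructReliabilityCalibrationTables(
        n_probability_bins=5,
        single_value_lower_limit=True,
        single_value_upper_limit=True,
    )
    spatial_coords = [
        historic_forecasts.coord(axis="x").name(),
        historic_forecasts.coord(axis="y").name(),
    ]
    return plugin.process(historic_forecasts, truths, aggregate_coords=spatial_coords)

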
class AggregateReliabilityCalibrationTables(BasePlugin):
    """This plugin enables the aggregation of multiple reliability calibration
    tables, and/or the aggregation over coordinates in the tables."""

    def __repr__(self) -> str:
        """Represent the configured plugin instance as a string."""
        return "<AggregateReliabilityCalibrationTables>"

    @staticmethod
    def _check_frt_coord(cubes: Union[List[Cube], CubeList]) -> None:
        """
        Check that the reliability calibration tables do not have overlapping
        forecast reference time bounds. If these coordinates overlap in time
        it indicates that some of the same forecast data has contributed to
        more than one table, thus aggregating them would double count these
        contributions.

        Args:
            cubes:
                The list of reliability calibration tables for which the
                forecast reference time coordinates should be checked.

        Raises:
            ValueError: If the bounds overlap.
        """
        lower_bounds = []
        upper_bounds = []
        for cube in cubes:
            lower_bounds.append(cube.coord("forecast_reference_time").bounds[0][0])
            upper_bounds.append(cube.coord("forecast_reference_time").bounds[0][1])
        if not all(x < y for x, y in zip(upper_bounds, lower_bounds[1:])):
            raise ValueError(
                "Reliability calibration tables have overlapping "
                "forecast reference time bounds, indicating that "
                "the same forecast data has contributed to the "
                "construction of both tables. Cannot aggregate."
            )

    def process(
        self,
        cubes: Union[CubeList, List[Cube]],
        coordinates: Optional[List[str]] = None,
    ) -> Cube:
        """
        Aggregate the input reliability calibration table cubes and return the
        result.

        Args:
            cubes:
                The cube or cubes containing the reliability calibration
                tables to aggregate.
            coordinates:
                A list of coordinates over which to aggregate the reliability
                calibration table using summation. If the argument is None and
                a single cube is provided, this cube will be returned
                unchanged.

        Returns:
            Aggregated cube.
        """
        coordinates = [] if coordinates is None else coordinates

        try:
            (cube,) = cubes
        except ValueError:
            cubes = iris.cube.CubeList(cubes)
            self._check_frt_coord(cubes)
            cube = cubes.merge_cube()
            coordinates.append("forecast_reference_time")
        else:
            if not coordinates:
                return cube

        result = collapsed(cube, coordinates, iris.analysis.SUM)
        frt = create_unified_frt_coord(cube.coord("forecast_reference_time"))
        result.replace_coord(frt)
        return result

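# Illustrative usage sketch (not part of the module): combining several
# reliability table cubes and collapsing their spatial coordinates by
# summation. The helper name and the cube list passed in are hypothetical.
def _example_aggregate_tables(reliability_tables: CubeList) -> Cube:
    """Merge multiple reliability tables and sum over their spatial coords."""
    first = reliability_tables[0]
    spatial_coords = [first.coord(axis="x").name(), first.coord(axis="y").name()]
    return AggregateReliabilityCalibrationTables().process(
        reliability_tables, coordinates=spatial_coords
    )

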
class ManipulateReliabilityTable(BasePlugin):
    """
    A plugin to manipulate the reliability tables before they are used to
    calibrate a forecast. x and y coordinates on the reliability table must be
    collapsed.

    The result is a reliability diagram with monotonic observation frequency.

    Steps taken are:

    1. If any bin contains less than the minimum forecast count then try
       combining this bin with whichever neighbour has the lowest sample
       count. This process is repeated for all bins that are below the
       minimum forecast count criterion.

    2. If non-monotonicity of the observation frequency is detected, try
       combining a pair of bins that appear non-monotonic. Only a single
       pair of bins are combined.

    3. If non-monotonicity of the observation frequency remains after trying
       to combine a single pair of bins, replace non-monotonic bins by
       assuming a constant observation frequency.
    """

    def __init__(
        self, minimum_forecast_count: int = 200, point_by_point: bool = False
    ) -> None:
        """
        Initialise class for manipulating a reliability table.

        Args:
            minimum_forecast_count:
                The minimum number of forecast counts in a forecast
                probability bin for it to be used in calibration. The default
                value of 200 is that used in Flowerdew 2014.
            point_by_point:
                Whether to process each point in the input cube independently.
                Please note this option is memory intensive and is unsuitable
                for gridded input.

        Raises:
            ValueError: If minimum_forecast_count is less than 1.

        References:
            Flowerdew J. 2014. Calibrating ensemble reliability whilst
            preserving spatial structure. Tellus, Ser. A Dyn. Meteorol.
            Oceanogr. 66.
        """
        if minimum_forecast_count < 1:
            raise ValueError(
                "The minimum_forecast_count must be at least 1 as empty "
                "bins in the reliability table are not handled."
            )
        self.minimum_forecast_count = minimum_forecast_count
        self.point_by_point = point_by_point

    @staticmethod
    def _extract_reliability_table_components(
        reliability_table: Cube,
    ) -> Tuple[ndarray, ndarray, ndarray, DimCoord]:
        """Extract reliability table components from cube.

        Args:
            reliability_table:
                A reliability table to be manipulated.

        Returns:
            Tuple containing the updated observation count, forecast
            probability sum, forecast count and probability bin coordinate.
        """
        observation_count = reliability_table.extract(
            iris.Constraint(table_row_name="observation_count")
        ).data
        forecast_probability_sum = reliability_table.extract(
            iris.Constraint(table_row_name="sum_of_forecast_probabilities")
        ).data
        forecast_count = reliability_table.extract(
            iris.Constraint(table_row_name="forecast_count")
        ).data
        probability_bin_coord = reliability_table.coord("probability_bin")
        return (
            observation_count,
            forecast_probability_sum,
            forecast_count,
            probability_bin_coord,
        )

    @staticmethod
    def _sum_pairs(array: ndarray, upper: int) -> ndarray:
        """
        Returns a new array where a pair of values in the original array have
        been replaced by their sum. Combines the value in the upper index with
        the value in the upper-1 index.

        Args:
            array:
                Array to be modified.
            upper:
                Upper index of pair.

        Returns:
            Array where a pair of values has been replaced by their sum.
        """
        result = array.copy()
        result[upper - 1] = np.sum(array[upper - 1 : upper + 1])
        return np.delete(result, upper)

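    # Worked example (illustrative): _sum_pairs([10, 20, 30, 40], upper=2)
    # returns [10, 50, 40], combining the values at indices 1 and 2.
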
    @staticmethod
    def _create_new_bin_coord(probability_bin_coord: DimCoord, upper: int) -> DimCoord:
        """
        Create a new probability_bin coordinate by combining two adjacent
        points on the probability_bin coordinate. This matches the combination
        of the data for the two bins.

        Args:
            probability_bin_coord:
                Original probability bin coordinate.
            upper:
                Upper index of pair.

        Returns:
            Probability bin coordinate with updated points and bounds where a
            pair of bins have been combined to create a single bin.
        """
        old_bounds = probability_bin_coord.bounds
        new_bounds = np.concatenate(
            (
                old_bounds[0 : upper - 1],
                np.array([[old_bounds[upper - 1, 0], old_bounds[upper, 1]]]),
                old_bounds[upper + 1 :],
            )
        )
        new_points = np.mean(new_bounds, axis=1, dtype=np.float32)
        new_bin_coord = iris.coords.DimCoord(
            new_points, long_name="probability_bin", units=1, bounds=new_bounds
        )
        return new_bin_coord

    def _combine_undersampled_bins(
        self,
        observation_count: ndarray,
        forecast_probability_sum: ndarray,
        forecast_count: ndarray,
        probability_bin_coord: DimCoord,
    ) -> Tuple[ndarray, ndarray, ndarray, DimCoord]:
        """
        Combine bins that are under-sampled i.e. that have a lower forecast
        count than the minimum_forecast_count, so that information from these
        poorly-sampled bins can contribute to the calibration. If multiple
        bins are below the minimum forecast count, the bin closest to meeting
        the minimum_forecast_count criterion is combined with whichever
        neighbour has the lowest sample count. A new bin is then created by
        summing the neighbouring pair of bins. This process is repeated for
        all bins that are below the minimum forecast count criterion.

        Args:
            observation_count:
                Observation count extracted from reliability table.
            forecast_probability_sum:
                Forecast probability sum extracted from reliability table.
            forecast_count:
                Forecast count extracted from reliability table.
            probability_bin_coord:
                Original probability bin coordinate.

        Returns:
            Tuple containing the updated observation count, forecast
            probability sum, forecast count and probability bin coordinate.
        """
        while (
            any(x < self.minimum_forecast_count for x in forecast_count)
            and len(forecast_count) > 1
        ):
            forecast_count_copy = forecast_count.copy()

            # Find index of the bin with the highest forecast count that is
            # below the minimum_forecast_count by setting forecast counts
            # greater than the minimum_forecast_count to NaN.
            forecast_count_copy[forecast_count >= self.minimum_forecast_count] = np.nan

            # Note for multiple occurrences of the maximum,
            # the index of the first occurrence is returned.
            index = np.int32(np.nanargmax(forecast_count_copy))

            # Determine the upper index of the pair of bins to be combined.
            if index == 0:
                # Must use higher bin.
                upper = index + 1
            elif index + 1 == len(forecast_count):
                # Index already defines the upper bin.
                upper = index
            else:
                # Define upper index to include bin with lowest sample count.
                if forecast_count[index + 1] > forecast_count[index - 1]:
                    upper = index
                else:
                    upper = index + 1

            forecast_count = self._sum_pairs(forecast_count, upper)
            observation_count = self._sum_pairs(observation_count, upper)
            forecast_probability_sum = self._sum_pairs(forecast_probability_sum, upper)
            probability_bin_coord = self._create_new_bin_coord(
                probability_bin_coord, upper
            )

        return (
            observation_count,
            forecast_probability_sum,
            forecast_count,
            probability_bin_coord,
        )

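    # Worked example (illustrative): with minimum_forecast_count=200 and
    # forecast counts [100, 150, 300, 400], the undersampled bin closest to
    # the threshold (150) is combined with its smaller neighbour (100), giving
    # counts [250, 300, 400]; the observation counts, forecast probability
    # sums and bin coordinate are combined in the same way.
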
    def _combine_bin_pair(
        self,
        observation_count: ndarray,
        forecast_probability_sum: ndarray,
        forecast_count: ndarray,
        probability_bin_coord: DimCoord,
    ) -> Tuple[ndarray, ndarray, ndarray, DimCoord]:
        """
        Combine a pair of bins when non-monotonicity of the observation
        frequency is detected. Iterate top-down from the highest forecast
        probability bin to the lowest probability bin when combining the bins.
        Only allow a single pair of bins to be combined.

        Args:
            observation_count:
                Observation count extracted from reliability table.
            forecast_probability_sum:
                Forecast probability sum extracted from reliability table.
            forecast_count:
                Forecast count extracted from reliability table.
            probability_bin_coord:
                Original probability bin coordinate.

        Returns:
            Tuple containing the updated observation count, forecast
            probability sum, forecast count and probability bin coordinate.
        """
        observation_frequency = np.array(observation_count / forecast_count)
        for upper in np.arange(len(observation_frequency) - 1, 0, -1):
            (diff,) = np.diff(
                [observation_frequency[upper - 1], observation_frequency[upper]]
            )
            if diff < 0:
                forecast_count = self._sum_pairs(forecast_count, upper)
                observation_count = self._sum_pairs(observation_count, upper)
                forecast_probability_sum = self._sum_pairs(
                    forecast_probability_sum, upper
                )
                probability_bin_coord = self._create_new_bin_coord(
                    probability_bin_coord, upper
                )
                break
        return (
            observation_count,
            forecast_probability_sum,
            forecast_count,
            probability_bin_coord,
        )

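    # Worked example (illustrative): observation frequencies
    # [0.1, 0.3, 0.25, 0.6, 0.8] are non-monotonic between bins 1 and 2;
    # iterating from the top, the first decreasing pair found is (0.3, 0.25),
    # so bins 1 and 2 are combined and no further pairs are merged.
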
    @staticmethod
    def _assume_constant_observation_frequency(
        observation_count: ndarray, forecast_count: ndarray
    ) -> ndarray:
        """
        Decide which end bin (highest probability bin or lowest probability
        bin) has the highest sample count. Iterate through the observation
        frequency from the end bin with the highest sample count to the end
        bin with the lowest sample count. Whilst iterating, compare each pair
        of bins and, if a pair is non-monotonic, replace the value of the bin
        closer to the lowest sample count end bin with the value of the bin
        that is closer to the higher sample count end bin. Then calculate the
        new observation count required to give a monotonic observation
        frequency.

        Args:
            observation_count:
                Observation count extracted from reliability table.
            forecast_count:
                Forecast count extracted from reliability table.

        Returns:
            Observation count computed from a monotonic observation frequency.
        """
        observation_frequency = np.array(observation_count / forecast_count)
        iterator = observation_frequency
        operation = operator.lt
        # Top down if forecast count is lower for the lowest probability bin
        # than for the highest probability bin.
        if forecast_count[0] < forecast_count[-1]:
            # Reverse array to iterate from top to bottom.
            iterator = observation_frequency[::-1]
            operation = operator.gt

        for index, lower_bin in enumerate(iterator[:-1]):
            (diff,) = np.diff([lower_bin, iterator[index + 1]])
            if operation(diff, 0):
                iterator[index + 1] = lower_bin

        observation_frequency = iterator
        if forecast_count[0] < forecast_count[-1]:
            # Re-reverse array from bottom to top to ensure original ordering.
            observation_frequency = iterator[::-1]

        observation_count = observation_frequency * forecast_count
        return observation_count

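    # Worked example (illustrative): with forecast counts [1000, 800, 500, 200]
    # (lowest bin best sampled) and observation frequencies
    # [0.2, 0.4, 0.3, 0.6], iteration runs bottom-up and the non-monotonic
    # value 0.3 is replaced by 0.4, giving frequencies [0.2, 0.4, 0.4, 0.6]
    # from which the new observation counts are recomputed.
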
    @staticmethod
    def _update_reliability_table(
        reliability_table: Cube,
        observation_count: ndarray,
        forecast_probability_sum: ndarray,
        forecast_count: ndarray,
        probability_bin_coord: DimCoord,
    ) -> Cube:
        """
        Update the reliability table data and the probability bin coordinate.

        Args:
            reliability_table:
                A reliability table to be manipulated.
            observation_count:
                Observation count extracted from reliability table.
            forecast_probability_sum:
                Forecast probability sum extracted from reliability table.
            forecast_count:
                Forecast count extracted from reliability table.
            probability_bin_coord:
                Original probability bin coordinate.

        Returns:
            Updated reliability table.
        """
        final_data = np.stack(
            [observation_count, forecast_probability_sum, forecast_count]
        )
        nrows, ncols = final_data.shape
        reliability_table = reliability_table[0:nrows, 0:ncols].copy(data=final_data)
        reliability_table.replace_coord(probability_bin_coord)
        return reliability_table

    def _enforce_min_count_and_montonicity(self, rel_table_slice: Cube) -> Cube:
        """Apply the steps needed to produce a reliability diagram on a single
        slice of reliability table cube.

        Args:
            rel_table_slice:
                The reliability table slice to be manipulated. The only
                coordinates expected on this cube are a table_row_index
                coordinate and corresponding table_row_name coordinate and a
                probability_bin coordinate.

        Returns:
            Processed reliability table slice, with reliability steps applied.
        """
        (
            observation_count,
            forecast_probability_sum,
            forecast_count,
            probability_bin_coord,
        ) = self._extract_reliability_table_components(rel_table_slice)

        if np.any(forecast_count < self.minimum_forecast_count):
            (
                observation_count,
                forecast_probability_sum,
                forecast_count,
                probability_bin_coord,
            ) = self._combine_undersampled_bins(
                observation_count,
                forecast_probability_sum,
                forecast_count,
                probability_bin_coord,
            )
            rel_table_slice = self._update_reliability_table(
                rel_table_slice,
                observation_count,
                forecast_probability_sum,
                forecast_count,
                probability_bin_coord,
            )

        # If the observation frequency is non-monotonic adjust the
        # reliability table.
        observation_frequency = np.array(observation_count / forecast_count)
        if not np.all(np.diff(observation_frequency) >= 0):
            (
                observation_count,
                forecast_probability_sum,
                forecast_count,
                probability_bin_coord,
            ) = self._combine_bin_pair(
                observation_count,
                forecast_probability_sum,
                forecast_count,
                probability_bin_coord,
            )
            observation_count = self._assume_constant_observation_frequency(
                observation_count, forecast_count
            )
            rel_table_slice = self._update_reliability_table(
                rel_table_slice,
                observation_count,
                forecast_probability_sum,
                forecast_count,
                probability_bin_coord,
            )
        return rel_table_slice

    def process(self, reliability_table: Cube) -> CubeList:
        """
        Apply the steps needed to produce a reliability diagram with a
        monotonic observation frequency.

        Args:
            reliability_table:
                A reliability table to be manipulated. The only coordinates
                expected on this cube are a threshold coordinate, a
                table_row_index coordinate and corresponding table_row_name
                coordinate and a probability_bin coordinate.

        Returns:
            CubeList containing a reliability table cube for each threshold in
            the input reliability table. For tables where monotonicity has
            been enforced the probability_bin coordinate will have one less
            bin than the tables that were already monotonic. If under-sampled
            bins have been combined, then the probability_bin coordinate will
            have been reduced until all bins have more than the
            minimum_forecast_count if possible; a single under-sampled bin
            will be returned if combining all bins is still insufficient to
            reach the minimum_forecast_count.
        """
        threshold_coord = find_threshold_coordinate(reliability_table)
        if self.point_by_point:
            y_name = reliability_table.coord(axis="y").name()
            x_name = reliability_table.coord(axis="x").name()

        reliability_table_cubelist = iris.cube.CubeList()
        for rel_table_threshold in reliability_table.slices_over(threshold_coord):
            if self.point_by_point:
                for rel_table_point in rel_table_threshold.slices_over(
                    [y_name, x_name]
                ):
                    rel_table_point_emcam = self._enforce_min_count_and_montonicity(
                        rel_table_point
                    )
                    reliability_table_cubelist.append(rel_table_point_emcam)
            else:
                rel_table_processed = self._enforce_min_count_and_montonicity(
                    rel_table_threshold
                )
                reliability_table_cubelist.append(rel_table_processed)
        return reliability_table_cubelist

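# Illustrative usage sketch (not part of the module): enforcing the minimum
# forecast count and a monotonic observation frequency on a constructed
# reliability table. The helper name is hypothetical.
def _example_manipulate_table(reliability_table: Cube) -> CubeList:
    """Return per-threshold reliability tables suitable for calibration."""
    plugin = ManipulateReliabilityTable(minimum_forecast_count=200)
    return plugin.process(reliability_table)

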
class ApplyReliabilityCalibration(PostProcessingPlugin):
    """
    A plugin for the application of reliability calibration to probability
    forecasts. This calibration is designed to improve the reliability of
    probability forecasts without significantly degrading their resolution.

    The method implemented here is described in Flowerdew J. 2014. Calibration
    is always applied as long as there are at least two bins within the input
    reliability table.

    References:
        Flowerdew J. 2014. Calibrating ensemble reliability whilst
        preserving spatial structure. Tellus, Ser. A Dyn. Meteorol.
        Oceanogr. 66.
    """

    def __init__(self, point_by_point: bool = False) -> None:
        """
        Initialise class for applying reliability calibration.

        Args:
            point_by_point:
                Whether to calibrate each point in the input cube
                independently. Utilising this option requires that each
                spatial point in the forecast cube has a corresponding spatial
                point in the reliability table. Please note this option is
                memory intensive and is unsuitable for gridded input.
        """
        self.threshold_coord = None
        self.point_by_point = point_by_point

    @staticmethod
    def _extract_matching_reliability_table(
        forecast: Cube, reliability_table: Union[Cube, CubeList]
    ) -> Cube:
        """
        Extract the reliability table with a threshold coordinate matching the
        forecast cube. If no matching reliability table is found raise an
        exception.

        Args:
            forecast:
                The forecast to be calibrated.
            reliability_table:
                The reliability table to use for applying calibration.

        Returns:
            A reliability table with a threshold coordinate that matches the
            forecast cube.

        Raises:
            ValueError: If no matching reliability table is found.
        """
        threshold_coord = find_threshold_coordinate(forecast)
        coord_values = {threshold_coord.name(): threshold_coord.points}
        constr = iris.Constraint(coord_values=coord_values)

        if isinstance(reliability_table, iris.cube.Cube):
            extracted = reliability_table.extract(constr)
        else:
            extracted = reliability_table.extract_cube(constr)

        if not extracted:
            raise ValueError(
                "No reliability table found to match threshold "
                f"{find_threshold_coordinate(forecast).points[0]}."
            )
        return extracted

    def _ensure_monotonicity_across_thresholds(self, cube: Cube) -> None:
        """
        Ensures that probabilities change monotonically relative to thresholds
        in the expected order, e.g. exceedance probabilities always remain the
        same or decrease as the threshold values increase, below threshold
        probabilities always remain the same or increase as the threshold
        values increase.

        Args:
            cube:
                The probability cube for which monotonicity is to be checked
                and enforced. This cube is modified in place.

        Raises:
            ValueError: Threshold coordinate lacks the
                spp__relative_to_threshold attribute.

        Warns:
            UserWarning: If the probabilities must be sorted to reinstate
                expected monotonicity following calibration.
        """
        if not cube.coord_dims(self.threshold_coord):
            return

        (threshold_dim,) = cube.coord_dims(self.threshold_coord)
        thresholding = probability_is_above_or_below(cube)
        if thresholding is None:
            msg = (
                "Cube threshold coordinate does not define whether "
                "thresholding is above or below the defined thresholds."
            )
            raise ValueError(msg)

        if (
            thresholding == "above"
            and not (np.diff(cube.data, axis=threshold_dim) <= 0).all()
        ):
            msg = (
                "Exceedance probabilities are not decreasing monotonically "
                "as the threshold values increase. Forced back into order."
            )
            warnings.warn(msg)
            cube.data = np.sort(cube.data, axis=threshold_dim)[::-1]

        if (
            thresholding == "below"
            and not (np.diff(cube.data, axis=threshold_dim) >= 0).all()
        ):
            msg = (
                "Below threshold probabilities are not increasing "
                "monotonically as the threshold values increase. Forced "
                "back into order."
            )
            warnings.warn(msg)
            cube.data = np.sort(cube.data, axis=threshold_dim)

    def _calculate_reliability_probabilities(
        self, reliability_table: Cube
    ) -> Tuple[Optional[ndarray], Optional[ndarray]]:
        """
        Calculates forecast probabilities and observation frequencies from the
        reliability table. If fewer than two bins are provided, Nones are
        returned as no calibration can be applied. Fewer than two bins can
        occur due to repeated combination of undersampled probability bins,
        please see :class:`.ManipulateReliabilityTable`.

        Args:
            reliability_table:
                A reliability table for a single threshold from which to
                calculate the forecast probabilities and observation
                frequencies.

        Returns:
            Tuple containing forecast probabilities calculated by dividing the
            sum of forecast probabilities by the forecast count and
            observation frequency calculated by dividing the observation count
            by the forecast count.
        """
        observation_count = reliability_table.extract(
            iris.Constraint(table_row_name="observation_count")
        ).data
        forecast_count = reliability_table.extract(
            iris.Constraint(table_row_name="forecast_count")
        ).data
        forecast_probability_sum = reliability_table.extract(
            iris.Constraint(table_row_name="sum_of_forecast_probabilities")
        ).data

        # If there are fewer than two bins, no calibration can be applied.
        if len(np.atleast_1d(forecast_count)) < 2:
            return None, None

        forecast_probability = np.array(forecast_probability_sum / forecast_count)
        observation_frequency = np.array(observation_count / forecast_count)
        return forecast_probability, observation_frequency

    @staticmethod
    def _interpolate(
        forecast_threshold: Union[MaskedArray, ndarray],
        reliability_probabilities: ndarray,
        observation_frequencies: ndarray,
    ) -> Union[MaskedArray, ndarray]:
        """
        Perform interpolation of the forecast probabilities using the
        reliability table data to produce the calibrated forecast. Where
        necessary linear extrapolation will be applied. Any mask in place on
        the forecast_threshold data is removed and reapplied after
        calibration.

        Args:
            forecast_threshold:
                The forecast probabilities to be calibrated.
            reliability_probabilities:
                Probabilities taken from the reliability tables.
            observation_frequencies:
                Observation frequencies that relate to the reliability
                probabilities, taken from the reliability tables.

        Returns:
            The calibrated forecast probabilities. The final results are
            clipped to ensure any extrapolation has not yielded probabilities
            outside the range 0 to 1.
        """
        shape = forecast_threshold.shape
        mask = forecast_threshold.mask if np.ma.is_masked(forecast_threshold) else None

        forecast_probabilities = np.ma.getdata(forecast_threshold).flatten()

        # Interpolate using scipy first to get extrapolated values at the
        # endpoints, since np.interp does not allow extrapolation. We would
        # need to change back to scipy.interpolate if we want non-linear
        # interpolation in future.
        interpolation_function = scipy.interpolate.interp1d(
            reliability_probabilities, observation_frequencies, fill_value="extrapolate"
        )
        y_0, y_1 = interpolation_function([0, 1])

        xp = np.copy(reliability_probabilities)
        # Extrapolation preserves the slope of the first and last segments of
        # the piecewise linear function. Thus the slope between [0, y_0] and
        # [xp[0], fp[0]] is the same as that between [xp[0], fp[0]] and
        # [xp[1], fp[1]], so we can replace [xp[0], fp[0]] with [0, y_0] to
        # extend the width of the first segment of the piecewise linear
        # function. A similar argument applies for the last segment.
        xp[0] = 0
        xp[-1] = 1
        fp = np.copy(observation_frequencies)
        fp[0] = y_0
        fp[-1] = y_1

        interpolated = np.interp(forecast_probabilities.data, xp, fp)
        interpolated = interpolated.reshape(shape).astype(np.float32)

        if mask is not None:
            interpolated = np.ma.masked_array(interpolated, mask=mask)

        return np.clip(interpolated, 0, 1)

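    # Worked example (illustrative): with reliability probabilities
    # [0.2, 0.5, 0.8] and observation frequencies [0.1, 0.4, 0.9], linear
    # extrapolation of the end segments gives y_0 = -0.1 at probability 0 and
    # y_1 of about 1.23 at probability 1. The endpoints of xp/fp are replaced
    # with these values so np.interp covers the full [0, 1] range, and the
    # final clip keeps calibrated probabilities within [0, 1].
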
    def _apply_calibration(
        self, forecast: Cube, reliability_table: Union[Cube, CubeList]
    ) -> Cube:
        """
        Apply reliability calibration to a forecast.

        Args:
            forecast:
                The forecast to be calibrated.
            reliability_table:
                The reliability table to use for applying calibration.

        Returns:
            The forecast cube following calibration.
        """
        calibrated_cubes = iris.cube.CubeList()
        forecast_thresholds = forecast.slices_over(self.threshold_coord)

        uncalibrated_thresholds = []
        for forecast_threshold in forecast_thresholds:
            reliability_threshold = self._extract_matching_reliability_table(
                forecast_threshold, reliability_table
            )
            (
                reliability_probabilities,
                observation_frequencies,
            ) = self._calculate_reliability_probabilities(reliability_threshold)

            if reliability_probabilities is None:
                calibrated_cubes.append(forecast_threshold)
                uncalibrated_thresholds.append(
                    forecast_threshold.coord(self.threshold_coord).points[0]
                )
                continue

            interpolated = self._interpolate(
                forecast_threshold.data,
                reliability_probabilities,
                observation_frequencies,
            )
            calibrated_cubes.append(forecast_threshold.copy(data=interpolated))

        calibrated_forecast = calibrated_cubes.merge_cube()
        self._ensure_monotonicity_across_thresholds(calibrated_forecast)

        if uncalibrated_thresholds:
            uncalibrated_thresholds = list(map(float, uncalibrated_thresholds))
            msg = (
                "The following thresholds were not calibrated due to "
                "insufficient forecast counts in reliability table bins: "
                "{}".format(uncalibrated_thresholds)
            )
            warnings.warn(msg)

        return calibrated_forecast

    def _apply_point_by_point_calibration(
        self, forecast: Cube, reliability_table: CubeList
    ) -> Cube:
        """
        Apply point by point reliability calibration by iteratively picking a
        spatial coordinate within the forecast cube, extracting the forecast
        at that point and the reliability table corresponding to that point,
        then passing the extracted forecast and reliability table to
        _apply_calibration().

        Args:
            forecast:
                The forecast to be calibrated.
            reliability_table:
                The reliability table to use for applying calibration.

        Returns:
            The forecast cube following calibration.
        """
        calibrated_cubes = iris.cube.CubeList()
        y_name = forecast.coord(axis="y").name()
        x_name = forecast.coord(axis="x").name()

        # Create list of dimensions.
        dim_names = get_dim_coord_names(forecast)
        dim_associated_coords = {}
        # Create dictionary with dimension name as keys, containing auxiliary
        # coordinates associated with that dimension.
        for dim_index, dim_name in enumerate(dim_names):
            associated_coords = [
                c for c in forecast.coords(dimensions=dim_index, dim_coords=False)
            ]
            dim_associated_coords[dim_name] = associated_coords

        # Slice over the spatial dimension/s of the forecast cube and apply
        # reliability calibration separately to each slice using a slice of
        # the input reliability table at the same spatial point.
        for forecast_point in forecast.slices_over([y_name, x_name]):
            y_point = forecast_point.coord(y_name).points[0]
            x_point = forecast_point.coord(x_name).points[0]
            # Create reliability table containing only those cubes relating to
            # the currently considered spatial point.
            reliability_table_point = reliability_table.extract(
                iris.Constraint(coord_values={y_name: y_point, x_name: x_point})
            )
            calibrated_cube = self._apply_calibration(
                forecast=forecast_point, reliability_table=reliability_table_point
            )
            # Remove auxiliary coordinates to ensure cubes can be merged into
            # the initial format later.
            for coords in dim_associated_coords.values():
                for coord in coords:
                    calibrated_cube.remove_coord(coord.name())
            calibrated_cubes.append(calibrated_cube)

        calibrated_forecast = calibrated_cubes.merge_cube()

        # Add auxiliary coordinates back to the calibrated cube.
        for dim_coord in dim_associated_coords.keys():
            for coord in dim_associated_coords[dim_coord]:
                dim = [x for x in calibrated_forecast.coord_dims(dim_coord)]
                calibrated_forecast.add_aux_coord(coord, dim)

        # Ensure that calibrated forecast dimensions are in the same order as
        # the dimensions in the input forecast.
        enforce_coordinate_ordering(calibrated_forecast, dim_names)

        return calibrated_forecast

    def process(self, forecast: Cube, reliability_table: Union[Cube, CubeList]) -> Cube:
        """
        Apply reliability calibration to a forecast. The reliability table
        and the forecast cube must share an identical threshold coordinate.

        Args:
            forecast:
                The forecast to be calibrated.
            reliability_table:
                The reliability table to use for applying calibration.

        Returns:
            The forecast cube following calibration.
        """
        self.threshold_coord = find_threshold_coordinate(forecast)
        if self.point_by_point:
            calibrated_forecast = self._apply_point_by_point_calibration(
                forecast=forecast, reliability_table=reliability_table
            )
        else:
            calibrated_forecast = self._apply_calibration(
                forecast=forecast, reliability_table=reliability_table
            )
        # Enforce correct data type.
        calibrated_forecast.data = calibrated_forecast.data.astype("float32")
        return calibrated_forecast

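
# Illustrative end-to-end sketch (not part of the module): applying the
# manipulated reliability tables to calibrate a thresholded probability
# forecast. The helper name is hypothetical.
def _example_calibrate_forecast(forecast: Cube, reliability_tables: CubeList) -> Cube:
    """Calibrate a probability forecast using prepared reliability tables."""
    return ApplyReliabilityCalibration().process(forecast, reliability_tables)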