Source code for improver.cli.estimate_samos_gams

#!/usr/bin/env python
# (C) Crown Copyright, Met Office. All rights reserved.
#
# This file is part of 'IMPROVER' and is released under the BSD 3-Clause license.
# See LICENSE in the root of the repository for full licensing details.
"""CLI to estimate the Generalized Additive Model (GAM) for Standardized Anomaly Model
Output Statistics (SAMOS)."""

from improver import cli


@cli.clizefy
@cli.with_output
def process(
    *cubes: cli.inputcube,
    truth_attribute: str,
    gam_features: cli.comma_separated_list,
    model_specification: cli.inputjson,
    max_iterations: int = 100,
    tolerance: float = 0.0001,
    distribution: str = "normal",
    link: str = "identity",
    fit_intercept: bool = True,
    window_length: int = 11,
    unique_site_id_key: str = "wmo_id",
):
    """Estimate Generalized Additive Model (GAM) for SAMOS.

    GAMs are estimated separately for the historical forecasts and for the
    corresponding truths, using the same model settings for both.

    Args:
        cubes (list of iris.cube.Cube):
            A list of cubes containing the historical forecasts and
            corresponding truth used for calibration. They must have the same
            cube name and will be separated based on the truth attribute.
            The list may also contain additional features (static predictors)
            that will be provided when estimating the GAM.
        truth_attribute (str):
            An attribute and its value in the format of "attribute=value",
            which must be present on historical truth cubes.
        gam_features (list of str):
            A list of the names of the cubes that will be used as additional
            features in the GAM.
        model_specification (list):
            The GAM model specification, loaded from JSON. A list containing
            three items (in order):
            1. a string containing a single pyGAM term; one of 'l' (linear),
            's' (spline), 'te' (tensor), or 'f' (factor)
            2. a list of integers which correspond to the features to be
            included in that term
            3. a dictionary of kwargs to be included when defining the term
        max_iterations (int):
            The maximum number of iterations to use when estimating the GAM
            coefficients.
        tolerance (float):
            The tolerance for the stopping criteria.
        distribution (str):
            The distribution to be used in the model. Valid options are
            normal, binomial, poisson, gamma, inv-gauss.
        link (str):
            The link function to be used in the model. Valid options are
            identity, logit, inverse, log or inverse-squared.
        fit_intercept (bool):
            Whether to include an intercept term in the model. Default is True.
        window_length (int):
            The length of the rolling window used to calculate the mean and
            standard deviation of the input cube when the input cube does not
            have a realization dimension coordinate. This must be an odd
            integer greater than 1. If a given window has fewer than half
            valid data points (not NaN) then the value returned for that
            window will be NaN and will be excluded from training.
        unique_site_id_key (str):
            If working with spot data and available, the name of the
            coordinate in the input cubes that contains unique site IDs,
            e.g. "wmo_id" if all sites have a valid wmo_id. For GAM
            estimation the default is "wmo_id" as we expect to have a
            training data set comprising matched obs and forecast sites.

    Returns:
        list or None:
            A list containing the fitted GAMs for the forecast and truth
            cubes in that order. None is returned if either the forecast or
            the truth cubes are not found in the inputs, or if GAM
            estimation returns None for either of them.
    """
    # Imported lazily so that building the CLI does not pull in the
    # calibration dependencies until the command is actually run.
    from improver.calibration import split_cubes_for_samos
    from improver.calibration.samos_calibration import TrainGAMsForSAMOS

    # Separate the inputs into forecast and truth cubes plus any additional
    # GAM feature fields. The three trailing outputs are not needed here
    # because no EMOS coefficients or EMOS fields are expected.
    (
        forecast,
        truth,
        gam_additional_fields,
        _,
        _,
        _,
    ) = split_cubes_for_samos(
        cubes=cubes,
        gam_features=gam_features,
        truth_attribute=truth_attribute,
        expect_emos_coeffs=False,
        expect_emos_fields=False,
    )

    # Both forecasts and truths are required; without either there is
    # nothing to train, so produce no output.
    if forecast is None or truth is None:
        return None

    # NOTE(review): pyGAM's own distribution/link keys appear to use
    # underscores (e.g. 'inv_gauss'); confirm the hyphenated names in the
    # docstring are accepted or translated by TrainGAMsForSAMOS.
    plugin = TrainGAMsForSAMOS(
        model_specification=model_specification,
        max_iter=max_iterations,
        tol=tolerance,
        distribution=distribution,
        link=link,
        fit_intercept=fit_intercept,
        window_length=window_length,
        unique_site_id_key=unique_site_id_key,
    )

    # The same plugin (and so the same model settings) is applied to the
    # truths and to the forecasts.
    truth_gams = plugin.process(
        input_cube=truth,
        features=gam_features,
        additional_fields=gam_additional_fields,
    )
    forecast_gams = plugin.process(
        input_cube=forecast,
        features=gam_features,
        additional_fields=gam_additional_fields,
    )

    if forecast_gams is None or truth_gams is None:
        return None

    # Output order is forecast first, then truth.
    return [forecast_gams, truth_gams]