Source code for improver.utilities.generalized_additive_models
# (C) Crown Copyright, Met Office. All rights reserved.
#
# This file is part of 'IMPROVER' and is released under the BSD 3-Clause license.
# See LICENSE in the root of the repository for full licensing details.
"""Module to contain methods for fitting and predicting using generalized additive
models."""
import warnings
from copy import deepcopy
from typing import List
import numpy as np
from improver import BasePlugin
[docs]
class GAMFit(BasePlugin):
"""
Class for fitting Generalized Additive Models (GAMs) which predict the mean or
standard deviation of input forecasts or observations.
This class uses functionality from pyGAM
(https://pygam.readthedocs.io/en/latest/index.html) to fit the model.
"""
[docs]
def __init__(
self,
model_specification: List,
max_iter: int = 100,
tol: float = 0.0001,
distribution: str = "normal",
link: str = "identity",
fit_intercept: bool = True,
):
"""
Initialize class for fitting GAMs using pyGAM.
Args:
model_specification:
A list containing lists of three items (in order):
1. a string containing a single pyGAM term; one of 'linear',
'spline', 'tensor', or 'factor'
2. a list of indices of the features to be included in that term,
corresponding to the index of those features in the predictor array
3. a dictionary of kwargs to be included when defining the term
max_iter:
A pyGAM argument which determines the maximum iterations allowed when
fitting the GAM. Defaults to 100.
tol:
A pyGAM argument determining the tolerance used to define the stopping
criteria. Defaults to 0.0001.
distribution:
A pyGAM argument determining the distribution to be used in the model.
The default is a normal distribution.
link:
A pyGAM argument determining the link function to be used in the model.
Defaults to the identity link function, which implies a direct
relationship between predictors and target.
fit_intercept:
A pyGAM argument determining whether to include an intercept term in
the model. Default is True.
"""
self.model_specification = model_specification
self.max_iter = max_iter
self.tol = tol
self.distribution = distribution
self.link = link
self.fit_intercept = fit_intercept
[docs]
def create_pygam_model(self):
"""
Create a GAM model using pyGAM from the model_specification dictionary.
Returns:
GAM model equation constructed using pyGAM model terms.
"""
# Import from pygam here to minimize dependencies
from pygam import f, l, s, te
term = {
"factor": f,
"linear": l,
"spline": s,
"tensor": te,
} # create dictionary of permissible pyGAM model terms
for index, config in enumerate(self.model_specification):
# For each config in the list, parse the config to create a pyGAM term
# from that config. The first term in the config defines the type of term,
# the second defines which variables are included in that term, and the
# third contains a dictionary of kwargs.
if config[0] in term.keys():
new_term = term[config[0]](*config[1], **config[2])
else:
msg = (
f"An unrecognised term has been included in the GAM model "
f"specification. The term was {config[0]}, the accepted terms are "
f"linear, spline, tensor, factor."
)
raise ValueError(msg)
if index == 0:
# Initialize the equation variable
eqn = deepcopy(new_term)
else:
# Add new term to the existing equation
eqn += new_term
return eqn
[docs]
def process(self, predictors: np.ndarray, targets: np.ndarray):
"""
Fit a GAM model using pyGAM.
Args:
predictors: A 2-D array of predictors. The index of each column (feature) is
used in model_specification to determine which feature is included in each
model term.
targets: A 1-D array of target values associated with the predictors.
Returns:
A fitted pyGAM GAM model.
"""
# Monkey patch for pyGAM due to handling of sparse arrays in some versions of
# scipy.
import scipy.sparse
def to_array(self):
return self.toarray()
scipy.sparse.spmatrix.A = property(to_array)
# Import from pygam here to minimize dependencies
from pygam import GAM
# Remove nans from arrays.
predictors = predictors[~np.isnan(targets)]
targets = targets[~np.isnan(targets)]
# Check that there is data to fit after removing nans.
if (len(predictors) == 0) or (len(targets) == 0):
msg = (
"After removing NaN values from the input data, there are no "
"remaining data points to fit the GAM model. No model has been fitted."
)
warnings.warn(msg)
return None
eqn = self.create_pygam_model()
gam = GAM(
eqn,
max_iter=self.max_iter,
tol=self.tol,
distribution=self.distribution,
link=self.link,
fit_intercept=self.fit_intercept,
).fit(predictors, targets)
return gam
[docs]
class GAMPredict(BasePlugin):
"""Class for predicting new outputs from a fitted GAM given new input predictors."""
[docs]
def __init__(self):
"""Initialize class"""
[docs]
def process(self, gam, predictors: np.ndarray) -> np.ndarray:
"""
Use pyGAM functionality to predict values from a fitted GAM.
Args:
gam: A fitted pyGAM GAM model.
predictors: A 2-D array of inputs to use to predict new values. Each
feature (column) should have the same index as in the training dataset.
Returns:
A 1-D array of values predicted by the GAM with each value in the array
corresponding to one row in the input predictors.
"""
return gam.predict(predictors)