"""Base Multivariate class."""
import pickle
import numpy as np
from copulas import NotFittedError, get_instance, validate_random_state
class Multivariate(object):
    """Abstract base class for a multivariate copula object.

    Subclasses are expected to override ``fit``, ``probability_density``,
    ``cumulative_distribution``, ``sample`` and ``to_dict``; the versions
    defined here only raise ``NotImplementedError``. The remaining methods
    (``pdf``, ``cdf``, ``log_probability_density``, ``save``, ``load``,
    ``check_fit``) are implemented on top of those hooks.
    """

    # Class-level default; ``check_fit`` reads it through the instance, so an
    # instance that sets ``self.fitted = True`` is treated as fitted.
    fitted = False

    def __init__(self, random_state=None):
        """Store the validated random state on the instance.

        Args:
            random_state (int, np.random.RandomState, or None):
                Seed or RandomState for the random generator.
        """
        self.random_state = validate_random_state(random_state)

    def fit(self, X):
        """Fit the model to a table with values from multiple random variables.

        Subclasses must override this method; the base implementation raises
        ``NotImplementedError``.

        Arguments:
            X (pandas.DataFrame):
                Values of the random variables.
        """
        raise NotImplementedError

    def probability_density(self, X):
        """Compute the probability density for each point in X.

        Subclasses must override this method; the base implementation raises
        ``NotImplementedError``.

        Arguments:
            X (pandas.DataFrame):
                Values for which the probability density will be computed.

        Returns:
            numpy.ndarray:
                Probability density values for points in X.

        Raises:
            NotFittedError:
                if the model is not fitted.
        """
        raise NotImplementedError

    def log_probability_density(self, X):
        """Compute the log of the probability density for each point in X.

        This is the natural logarithm of whatever ``probability_density``
        returns for ``X``.

        Arguments:
            X (pandas.DataFrame):
                Values for which the log probability density will be computed.

        Returns:
            numpy.ndarray:
                Log probability density values for points in X.

        Raises:
            NotFittedError:
                if the model is not fitted.
        """
        return np.log(self.probability_density(X))

    def pdf(self, X):
        """Compute the probability density for each point in X.

        Short alias that forwards to ``probability_density``.

        Arguments:
            X (pandas.DataFrame):
                Values for which the probability density will be computed.

        Returns:
            numpy.ndarray:
                Probability density values for points in X.

        Raises:
            NotFittedError:
                if the model is not fitted.
        """
        return self.probability_density(X)

    def cumulative_distribution(self, X):
        """Compute the cumulative distribution value for each point in X.

        Subclasses must override this method; the base implementation raises
        ``NotImplementedError``.

        Arguments:
            X (pandas.DataFrame):
                Values for which the cumulative distribution will be computed.

        Returns:
            numpy.ndarray:
                Cumulative distribution values for points in X.

        Raises:
            NotFittedError:
                if the model is not fitted.
        """
        raise NotImplementedError

    def cdf(self, X):
        """Compute the cumulative distribution value for each point in X.

        Short alias that forwards to ``cumulative_distribution``.

        Arguments:
            X (pandas.DataFrame):
                Values for which the cumulative distribution will be computed.

        Returns:
            numpy.ndarray:
                Cumulative distribution values for points in X.

        Raises:
            NotFittedError:
                if the model is not fitted.
        """
        return self.cumulative_distribution(X)

    def set_random_state(self, random_state):
        """Set the random state.

        Args:
            random_state (int, np.random.RandomState, or None):
                Seed or RandomState for the random generator.
        """
        self.random_state = validate_random_state(random_state)

    def sample(self, num_rows=1):
        """Sample values from this model.

        Subclasses must override this method; the base implementation raises
        ``NotImplementedError``.

        Args:
            num_rows (int):
                Number of rows to sample.

        Returns:
            numpy.ndarray:
                Array of shape (n_samples, *) with values randomly
                sampled from this model distribution.

        Raises:
            NotFittedError:
                if the model is not fitted.
        """
        raise NotImplementedError

    def to_dict(self):
        """Return a `dict` with the parameters to replicate this object.

        Subclasses must override this method; the base implementation raises
        ``NotImplementedError``.

        Returns:
            dict:
                Parameters of this distribution.
        """
        raise NotImplementedError

    @classmethod
    def from_dict(cls, params):
        """Create a new instance from a parameters dictionary.

        The object resolved by ``get_instance`` from ``params['type']`` does
        the actual construction through its own ``from_dict``.

        Args:
            params (dict):
                Parameters of the distribution, in the same format as the one
                returned by the ``to_dict`` method.

        Returns:
            Multivariate:
                Instance of the distribution defined on the parameters.
        """
        # NOTE(review): if the resolved type does not override ``from_dict``,
        # this call lands back in this method -- confirm subclasses override it.
        multivariate_class = get_instance(params['type'])
        return multivariate_class.from_dict(params)

    @classmethod
    def load(cls, path):
        """Load a Multivariate instance from a pickle file.

        Warning:
            Unpickling can execute arbitrary code. Only load files that come
            from a trusted source.

        Args:
            path (str):
                Path to the pickle file where the distribution has been serialized.

        Returns:
            Multivariate:
                Loaded instance. The unpickled object is returned as is; its
                type is not checked against ``cls``.
        """
        with open(path, 'rb') as pickle_file:
            # SECURITY: pickle.load on untrusted data is unsafe; see docstring.
            return pickle.load(pickle_file)

    def save(self, path):
        """Serialize this multivariate instance using pickle.

        Args:
            path (str):
                Path to where this distribution will be serialized.
        """
        with open(path, 'wb') as pickle_file:
            pickle.dump(self, pickle_file)

    def check_fit(self):
        """Check whether this model has already been fit to a random variable.

        Raise a ``NotFittedError`` if it has not.

        Raises:
            NotFittedError:
                if the model is not fitted.
        """
        if not self.fitted:
            raise NotFittedError('This model is not fitted.')