diff --git a/.gitignore b/.gitignore index a0185fcb23b7182ef37b214ad0bb3443d3786ac7..c5443b2cfce03389b9320db4ead6e8acc280ac0d 100644 --- a/.gitignore +++ b/.gitignore @@ -150,9 +150,16 @@ test.py ## directories that outputs when running the tests tests/Pln* +tests/ZIPln* slides/ index.html -tests/test_models* -tests/test_load* -tests/test_readme* +paper/* + + +tests/docstrings_examples/* +tests/getting_started/* +tests/readme_examples/* +# tests/test_getting_started.py +Getting_started.py +new_model.py diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index d8c20b0f6ba469cda6260c8d4f2cded6f2d6a55f..e45adf607f75106844c73ed25d3faccaf8f91d60 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -15,14 +15,19 @@ black: tests: stage: checks - image: "registry.forgemia.inra.fr/jbleger/docker-image-pandas-torch-sphinx:master" + image: "registry.forgemia.inra.fr/bbatardiere/docker-image-pandas-torch-sphinx-jupyter" before_script: pip install '.[tests]' script: - pip install . + - jupyter nbconvert Getting_started.ipynb --to python --output tests/test_getting_started - cd tests - - python create_readme_and_docstrings_tests.py + - python create_readme_getting_started_and_docstrings_tests.py + - rm test_getting_started.py - pytest . + only: + - main + - dev build_package: diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index faf13f3bb1a194b3f8966a6f68643464dc7de495..2f718217e57d26b92932f72b09b3f2f26029b7a8 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,15 +1,121 @@ -# Clone the repo +# What to work on +A public roadmap will be available soon. + + +## Fork/clone/pull + +The typical workflow for contributing is: + +1. Fork the `main` branch from the [GitLab repository](https://forgemia.inra.fr/bbatardiere/pyplnmodels). +2. Clone your fork locally. +3. Run `pip install pre-commit` if pre-commit is not already installed. +4. Inside the repository, run 'pre-commit install'. +5. Commit changes. +6. Push the changes to your fork. +7. Send a pull request from your fork back to the original `main` branch. + +## How to implement a new model +You can implement a new model `newmodel` by inheriting from the abstract `_model` class in the `models` module. +The `newmodel` class should contains at least the following code: ``` -git clone git@forgemia.inra.fr:bbatardiere/pyplnmodels +class newmodel(_model): + _NAME="" + @property + def latent_variables(self) -> torch.Tensor: + "Implement here" + + def compute_elbo(self) -> torch.Tensor: + "Implement here" + + def _compute_elbo_b(self) -> torch.Tensor: + "Implement here" + + def _smart_init_model_parameters(self)-> None: + "Implement here" + + def _random_init_model_parameters(self)-> None: + "Implement here" + + def _smart_init_latent_parameters(self)-> None: + "Implement here" + + def _random_init_latent_parameters(self)-> None: + "Implement here" + + @property + def _list_of_parameters_needing_gradient(self)-> list: + "Implement here" + @property + def _description(self)-> str: + "Implement here" + + @property + def number_of_parameters(self) -> int: + "Implement here" + + @property + def model_parameters(self)-> Dict[str, torch.Tensor]: + "Implement here" + + @property + def latent_parameters(self)-> Dict[str, torch.Tensor]: + "Implement here" ``` +Each value of the 'latent_parameters' dict should be implemented (and protected) both in the +`_random_init_latent_parameters` and '_smart_init_latent_parameters'. 
+Each value of the 'model_parameters' dict should be implemented (and protected) both in the +`_random_init_model_parameters` and '_smart_init_model_parameters'. +For example, if you have one model parameters `coef` and latent_parameters `latent_mean` and `latent_var`, you should implement such as +```py +class newmodel(_model): + @property + def model_parameters(self) -> Dict[str, torch.Tensor]: + return {"coef":self.coef} + @property + def latent_parameters(self) -> Dict[str, torch.Tensor]: + return {"latent_mean":self.latent_mean, "latent_var":self.latent_var} + + def _random_init_latent_parameters(self): + self._latent_mean = init_latent_mean() + self._latent_var = init_latent_var() + + @property + def _smart_init_model_parameters(self): + self._latent_mean = random_init_latent_mean() + self._latent_var = random_init_latent_var() + + @property + def latent_var(self): + return self._latent_var -# Install precommit + @property + def latent_mean(self): + return self._latent_mean -In the directory: + def _random_init_model_parameters(self): + self._coef = init_coef() + def _smart_init_model_parameters(self): + self._coef = random_init_latent_coef() + + @property + def coef(self): + return self._coef ``` -pre-commit install + + + +Then, add `newmodel` in the `__init__.py` file of the pyPLNmodels module. +If `newmodel` is well implemented, running ``` +from pyPLNmodels import newmodel, get_real_count_data -If not found use `pip install pre-commit` before this command. +endog = get_real_count_data() +model = newmodel(endog, add_const = True) +model.fit(nb_max_iteration = 10, tol = 0) +``` +should increase the elbo of the model. You should document your functions with +[numpy-style +docstrings](https://numpydoc.readthedocs.io/en/latest/format.html). You can use +the `_add_doc` decorator (implemented in the `_utils` module) to inherit the docstrings of the `_model` class. diff --git a/README.md b/README.md index 9401cfe60161c2e62f5626201a4e674d06f3721e..f8adaa3f20d8c0665962d4b33ace865a8718206a 100644 --- a/README.md +++ b/README.md @@ -16,22 +16,10 @@ <!-- > slides](https://pln-team.github.io/slideshow/) for a --> <!-- > comprehensive introduction. --> -## Getting started -The getting started can be found [here](https://forgemia.inra.fr/bbatardiere/pyplnmodels/-/raw/dev/Getting_started.ipynb?inline=false). If you need just a quick view of the package, see next. +## Getting started +The getting started can be found [here](https://forgemia.inra.fr/bbatardiere/pyplnmodels/-/raw/dev/Getting_started.ipynb?inline=false). If you need just a quick view of the package, see the quickstart next. -## Installation - -**pyPLNmodels** is available on -[pypi](https://pypi.org/project/pyPLNmodels/). The development -version is available on [GitHub](https://github.com/PLN-team/pyPLNmodels). - -### Package installation - -``` -pip install pyPLNmodels -``` - -## Usage and main fitting functions +## âš¡ï¸ Quickstart The package comes with an ecological data set to present the functionality ``` @@ -61,7 +49,24 @@ transformed_data = pln.transform() ``` -## References +## 🛠Installation + +**pyPLNmodels** is available on +[pypi](https://pypi.org/project/pyPLNmodels/). The development +version is available on [GitHub](https://github.com/PLN-team/pyPLNmodels). + +### Package installation + +``` +pip install pyPLNmodels +``` + +## 👠Contributing + +Feel free to contribute, but read the [CONTRIBUTING.md](https://forgemia.inra.fr/bbatardiere/pyplnmodels/-/blob/main/CONTRIBUTING.md) first. 
A public roadmap will be available soon. + + +## âš¡ï¸ Citations Please cite our work using the following references: - J. Chiquet, M. Mariadassou and S. Robin: Variational inference for diff --git a/docs/source/index.rst b/docs/source/index.rst index 98f3e0a6af4eb6aa794c9e81fc809c69e3115925..da418320ba1183085534bb905f4e7bdbb0cd07bc 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -16,6 +16,7 @@ API documentation ./plnpcacollection.rst ./plnpca.rst ./pln.rst + ./zipln.rst .. toctree:: :maxdepth: 1 diff --git a/docs/source/zipln.rst b/docs/source/zipln.rst new file mode 100644 index 0000000000000000000000000000000000000000..ae0e1e813f3db10bbf752050ce5444d293ada203 --- /dev/null +++ b/docs/source/zipln.rst @@ -0,0 +1,10 @@ + +ZIPln +=== + +.. autoclass:: pyPLNmodels.ZIPln + :members: + :inherited-members: + :special-members: __init__ + :undoc-members: + :show-inheritance: diff --git a/pyPLNmodels/__init__.py b/pyPLNmodels/__init__.py index e785b2881e493ace84a2f4c279d8c83b5ed05e77..6ed723c750d005cf75c8ba8e75b5ac883353ee30 100644 --- a/pyPLNmodels/__init__.py +++ b/pyPLNmodels/__init__.py @@ -1,4 +1,4 @@ -from .models import PlnPCAcollection, Pln, PlnPCA # pylint:disable=[C0114] +from .models import PlnPCAcollection, Pln, PlnPCA, ZIPln # pylint:disable=[C0114] from .oaks import load_oaks from .elbos import profiled_elbo_pln, elbo_plnpca, elbo_pln from ._utils import ( diff --git a/pyPLNmodels/_closed_forms.py b/pyPLNmodels/_closed_forms.py index b57e78505f72e43064a8088b93959d582a4fe60d..3524d48d524c5e4cd6be77c1233cd7231894ec1e 100644 --- a/pyPLNmodels/_closed_forms.py +++ b/pyPLNmodels/_closed_forms.py @@ -1,4 +1,5 @@ from typing import Optional +from ._utils import phi import torch # pylint:disable=[C0114] @@ -98,3 +99,17 @@ def _closed_formula_pi( """ poiss_param = torch.exp(offsets + latent_mean + 0.5 * torch.square(latent_sqrt_var)) return torch._sigmoid(poiss_param + torch.mm(exog, _coef_inflation)) * dirac + + +def _closed_formula_latent_prob(exog, coef, coef_infla, cov, dirac): + if exog is not None: + XB = exog @ coef + XB_zero = exog @ coef_infla + else: + XB_zero = 0 + XB = 0 + XB_zero = exog @ coef_infla + pi = torch.sigmoid(XB_zero) + diag = torch.diag(cov) + full_diag = diag.expand(exog.shape[0], -1) + return torch.sigmoid(XB_zero - torch.log(phi(XB, full_diag))) * dirac diff --git a/pyPLNmodels/_initialization.py b/pyPLNmodels/_initialization.py index e0c3f47e4dd02272eb8bc05b5e4ac15125bdeb1a..fe649fe056b469333575e8c353d90f1c824b4f46 100644 --- a/pyPLNmodels/_initialization.py +++ b/pyPLNmodels/_initialization.py @@ -2,6 +2,11 @@ import torch import math from typing import Optional from ._utils import _log_stirling +import time +from sklearn.decomposition import PCA +import seaborn as sns +import matplotlib.pyplot as plt +import numpy as np if torch.cuda.is_available(): DEVICE = torch.device("cuda") @@ -9,9 +14,7 @@ else: DEVICE = torch.device("cpu") -def _init_covariance( - endog: torch.Tensor, exog: torch.Tensor, coef: torch.Tensor -) -> torch.Tensor: +def _init_covariance(endog: torch.Tensor, exog: torch.Tensor) -> torch.Tensor: """ Initialization for the covariance for the Pln model. Take the log of endog (careful when endog=0), and computes the Maximum Likelihood @@ -40,9 +43,7 @@ def _init_covariance( return sigma_hat -def _init_components( - endog: torch.Tensor, exog: torch.Tensor, coef: torch.Tensor, rank: int -) -> torch.Tensor: +def _init_components(endog: torch.Tensor, rank: int) -> torch.Tensor: """ Initialization for components for the Pln model. 
Get a first guess for covariance that is easier to estimate and then takes the rank largest eigenvectors to get components. @@ -51,12 +52,6 @@ def _init_components( ---------- endog : torch.Tensor Samples with size (n,p) - offsets : torch.Tensor - Offset, size (n,p) - exog : torch.Tensor - Covariates, size (n,d) - coef : torch.Tensor - Coefficient of size (d,p) rank : int The dimension of the latent space, i.e. the reduced dimension. @@ -65,9 +60,11 @@ def _init_components( torch.Tensor Initialization of components of size (p,rank) """ - sigma_hat = _init_covariance(endog, exog, coef).detach() - components = _components_from_covariance(sigma_hat, rank) - return components + log_y = torch.log(endog + (endog == 0) * math.exp(-2)) + pca = PCA(n_components=rank) + pca.fit(log_y.detach().cpu()) + pca_comp = pca.components_.T * np.sqrt(pca.explained_variance_) + return torch.from_numpy(pca_comp).to(DEVICE) def _init_latent_mean( @@ -102,7 +99,7 @@ def _init_latent_mean( The learning rate of the optimizer. Default is 0.01. eps : float, optional The tolerance. The algorithm will stop as soon as the criterion is lower than the tolerance. - Default is 7e-3. + Default is 7e-1. Returns ------- diff --git a/pyPLNmodels/_utils.py b/pyPLNmodels/_utils.py index 805c9dca4e67fe0e318160597efecfc041cef730..1cb9d2cd2bcef9dbf16321f41d4a6ecefd258d8d 100644 --- a/pyPLNmodels/_utils.py +++ b/pyPLNmodels/_utils.py @@ -23,8 +23,11 @@ else: DEVICE = torch.device("cpu") -class _PlotArgs: - def __init__(self, window: int): +BETA = 0.03 + + +class _CriterionArgs: + def __init__(self): """ Initialize the PlotArgs class. @@ -33,10 +36,34 @@ class _PlotArgs: window : int The size of the window for computing the criterion. """ - self.window = window self.running_times = [] - self.criterions = [1] * window # the first window criterion won't be computed. self._elbos_list = [] + self.cumulative_elbo_list = [0] + self.new_derivative = 0 + self.normalized_elbo_list = [] + self.criterion_list = [1] + self.criterion = 1 + + def update_criterion(self, elbo, running_time): + self._elbos_list.append(elbo) + self.running_times.append(running_time) + self.cumulative_elbo_list.append(self.cumulative_elbo + elbo) + self.normalized_elbo_list.append(-elbo / self.cumulative_elbo_list[-1]) + if self.iteration_number > 1: + current_derivative = np.abs( + (self.normalized_elbo_list[-2] - self.normalized_elbo_list[-1]) + / (self.running_times[-2] - self.running_times[-1]) + ) + old_derivative = self.new_derivative + self.new_derivative = ( + self.new_derivative * (1 - BETA) + current_derivative * BETA + ) + current_hessian = np.abs( + (self.new_derivative - old_derivative) + / (self.running_times[-2] - self.running_times[-1]) + ) + self.criterion = self.criterion * (1 - BETA) + current_hessian * BETA + self.criterion_list.append(self.criterion) @property def iteration_number(self) -> int: @@ -50,6 +77,10 @@ class _PlotArgs: """ return len(self._elbos_list) + @property + def cumulative_elbo(self): + return self.cumulative_elbo_list[-1] + def _show_loss(self, ax=None): """ Show the loss of the model (i.e. the negative ELBO). 
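The `update_criterion` method above replaces the old fixed-window criterion: it exponentially smooths (weight `BETA = 0.03`) the slope of the normalized ELBO and then smooths the change of that slope, so fitting stops once the ELBO curve flattens. A standalone sketch of the recurrence on a toy trajectory (the ELBO and timing values below are made up for illustration):

```py
import numpy as np

BETA = 0.03  # same smoothing weight as in _utils.py

def smoothed_criterion(elbos, times):
    cumulative, criterion, derivative = 0.0, 1.0, 0.0
    normalized = []
    for i, elbo in enumerate(elbos):
        cumulative += elbo
        normalized.append(-elbo / cumulative)
        if i > 0:
            dt = times[i - 1] - times[i]
            current = np.abs((normalized[-2] - normalized[-1]) / dt)
            old = derivative
            derivative = derivative * (1 - BETA) + current * BETA
            hessian = np.abs((derivative - old) / dt)
            criterion = criterion * (1 - BETA) + hessian * BETA
    return criterion

elbos = [-1000 + 50 * (1 - 0.9**k) for k in range(200)]  # toy, flattening ELBO
times = [0.01 * (k + 1) for k in range(200)]             # toy running times
print(smoothed_criterion(elbos, times))                  # small once the curve is flat
```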
@@ -80,8 +111,8 @@ class _PlotArgs: """ ax = plt.gca() if ax is None else ax ax.plot( - self.running_times[self.window :], - self.criterions[self.window :], + self.running_times, + self.criterion_list, label="Delta", ) ax.set_yscale("log") @@ -170,7 +201,10 @@ def _log_stirling(integer: torch.Tensor) -> torch.Tensor: def _trunc_log(tens: torch.Tensor, eps: float = 1e-16) -> torch.Tensor: - integer = torch.min(torch.max(tens, torch.tensor([eps])), torch.tensor([1 - eps])) + integer = torch.min( + torch.max(tens, torch.tensor([eps]).to(DEVICE)), + torch.tensor([1 - eps]).to(DEVICE), + ) return torch.log(integer) @@ -306,12 +340,12 @@ def _format_model_param( exog = _format_data(exog) if add_const is True: if exog is None: - exog = torch.ones(endog.shape[0], 1) + exog = torch.ones(endog.shape[0], 1).to(DEVICE) else: if _has_null_variance(exog) is False: exog = torch.concat( (exog, torch.ones(endog.shape[0]).unsqueeze(1)), dim=1 - ) + ).to(DEVICE) if offsets is None: if offsets_formula == "logsum": print("Setting the offsets as the log of the sum of endog") @@ -463,8 +497,12 @@ def _get_simulation_components(dim: int, rank: int) -> torch.Tensor: return components.to("cpu") -def _get_simulation_coef_cov_offsets( - n_samples: int, nb_cov: int, dim: int, add_const: bool +def _get_simulation_coef_cov_offsets_coefzi( + n_samples: int, + nb_cov: int, + dim: int, + add_const: bool, + zero_inflated: bool, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Get offsets, covariance coefficients with right shapes. @@ -482,6 +520,8 @@ def _get_simulation_coef_cov_offsets( Dimension required of the data. add_const : bool, optional If True, will add a vector of ones in the exog. + zero_inflated : bool + If True, will return a zero_inflated coefficient. Returns ------- @@ -506,14 +546,23 @@ def _get_simulation_coef_cov_offsets( if add_const is True: exog = torch.cat((exog, torch.ones(n_samples, 1)), axis=1) if exog is None: + if zero_inflated is True: + msg = "Can not instantiate a zero inflate model without covariates." + msg += " Please give at least an intercept by setting add_const to True" + raise ValueError(msg) coef = None + coef_inflation = None else: coef = torch.randn(exog.shape[1], dim, device="cpu") + if zero_inflated is True: + coef_inflation = torch.randn(exog.shape[1], dim, device="cpu") + else: + coef_inflation = None offsets = torch.randint( low=0, high=2, size=(n_samples, dim), dtype=torch.float64, device="cpu" ) torch.random.set_rng_state(prev_state) - return coef, exog, offsets + return coef, exog, offsets, coef_inflation class PlnParameters: @@ -524,7 +573,7 @@ class PlnParameters: coef: Union[torch.Tensor, np.ndarray, pd.DataFrame], exog: Union[torch.Tensor, np.ndarray, pd.DataFrame], offsets: Union[torch.Tensor, np.ndarray, pd.DataFrame], - coef_inflation=None, + coef_inflation: Union[torch.Tensor, np.ndarray, pd.DataFrame, None] = None, ): """ Instantiate all the needed parameters to sample from the PLN model. @@ -539,9 +588,8 @@ class PlnParameters: Covariates, size (n, d) or None offsets : : Union[torch.Tensor, np.ndarray, pd.DataFrame](keyword-only) Offset, size (n, p) - _coef_inflation : : Union[torch.Tensor, np.ndarray, pd.DataFrame] or None, optional(keyword-only) + coef_inflation : Union[torch.Tensor, np.ndarray, pd.DataFrame, None], optional(keyword-only) Coefficient for zero-inflation model, size (d, p) or None. Default is None. 
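For reference, `PlnParameters` bundles everything needed to sample from the (possibly zero-inflated) PLN model; a minimal construction sketch, assuming it can be imported from `pyPLNmodels._utils` and using the tensor shapes stated in the surrounding docstrings:

```py
import torch
from pyPLNmodels._utils import PlnParameters  # import location assumed

n, p, d, rank = 50, 10, 2, 5
params = PlnParameters(
    components=torch.randn(p, rank),   # low-rank components (p, rank)
    coef=torch.randn(d, p),            # regression coefficients (d, p)
    exog=torch.randn(n, d),            # covariates (n, d)
    offsets=torch.zeros(n, p),         # offsets (n, p)
    coef_inflation=torch.randn(d, p),  # zero-inflation coefficients (d, p), optional
)
```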
- """ self._components = _format_data(components) self._coef = _format_data(coef) @@ -682,6 +730,7 @@ def get_simulation_parameters( nb_cov: int = 1, rank: int = 5, add_const: bool = True, + zero_inflated: bool = False, ) -> PlnParameters: """ Generate simulation parameters for a Poisson-lognormal model. @@ -700,18 +749,26 @@ def get_simulation_parameters( The rank of the data components, by default 5. add_const : bool, optional(keyword-only) If True, will add a vector of ones in the exog. + zero_inflated : bool, optional(keyword-only) + If True, the model will be zero inflated. + Default is False. Returns ------- PlnParameters The generated simulation parameters. - """ - coef, exog, offsets = _get_simulation_coef_cov_offsets( - n_samples, nb_cov, dim, add_const + coef, exog, offsets, coef_inflation = _get_simulation_coef_cov_offsets_coefzi( + n_samples, nb_cov, dim, add_const, zero_inflated ) components = _get_simulation_components(dim, rank) - return PlnParameters(components=components, coef=coef, exog=exog, offsets=offsets) + return PlnParameters( + components=components, + coef=coef, + exog=exog, + offsets=offsets, + coef_inflation=coef_inflation, + ) def get_simulated_count_data( @@ -722,6 +779,7 @@ def get_simulated_count_data( nb_cov: int = 1, return_true_param: bool = False, add_const: bool = True, + zero_inflated=False, seed: int = 0, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ @@ -741,19 +799,45 @@ def get_simulated_count_data( Number of exog, by default 1. return_true_param : bool, optional(keyword-only) Whether to return the true parameters of the model, by default False. + zero_inflated: bool, optional(keyword-only) + Whether to use a zero inflated model or not. + Default to False. seed : int, optional(keyword-only) Seed value for random number generation, by default 0. Returns ------- - Tuple[torch.Tensor, torch.Tensor, torch.Tensor] - Tuple containing endog, exog, and offsets. + if return_true_param is False: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor] + Tuple containing endog, exog, and offsets. + else: + if zero_inflated is True: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + Tuple containing endog, exog, offsets, covariance, coef, coef_inflation . + else: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + Tuple containing endog, exog, offsets, covariance, coef. + """ pln_param = get_simulation_parameters( - n_samples=n_samples, dim=dim, nb_cov=nb_cov, rank=rank, add_const=add_const + n_samples=n_samples, + dim=dim, + nb_cov=nb_cov, + rank=rank, + add_const=add_const, + zero_inflated=zero_inflated, ) endog = sample_pln(pln_param, seed=seed, return_latent=False) if return_true_param is True: + if zero_inflated is True: + return ( + endog, + pln_param.exog, + pln_param.offsets, + pln_param.covariance, + pln_param.coef, + pln_param.coef_inflation, + ) return ( endog, pln_param.exog, @@ -761,7 +845,7 @@ def get_simulated_count_data( pln_param.covariance, pln_param.coef, ) - return pln_param.endog, pln_param.cov, pln_param.offsets + return endog, pln_param.exog, pln_param.offsets def get_real_count_data( @@ -848,6 +932,10 @@ def _extract_data_from_formula( A tuple containing the extracted endog, exog, and offsets. 
""" + # dmatrices can not deal with GPU matrices + for key, matrix in data.items(): + if isinstance(matrix, torch.Tensor): + data[key] = matrix.cpu() dmatrix = dmatrices(formula, data=data) endog = dmatrix[0] exog = dmatrix[1] @@ -1005,3 +1093,60 @@ def _add_doc(parent_class, *, params=None, example=None, returns=None, see_also= return fun return wrapper + + +def pf_lambert(x, y): + return x - (1 - (y * torch.exp(-x) + 1) / (x + 1)) + + +def lambert(y, nb_pf=10): + x = torch.log(1 + y) + for _ in range(nb_pf): + x = pf_lambert(x, y) + return x + + +def d_varpsi_x1(mu, sigma2): + W = lambert(sigma2 * torch.exp(mu)) + first = phi(mu, sigma2) + third = 1 / sigma2 + 1 / 2 * 1 / ((1 + W) ** 2) + return -first * W * third + + +def phi(mu, sigma2): + y = sigma2 * torch.exp(mu) + lamby = lambert(y) + log_num = -1 / (2 * sigma2) * (lamby**2 + 2 * lamby) + return torch.exp(log_num) / torch.sqrt(1 + lamby) + + +def d_varpsi_x2(mu, sigma2): + first = d_varpsi_x1(mu, sigma2) / sigma2 + W = lambert(sigma2 * torch.exp(mu)) + second = (W**2 + 2 * W) / 2 / (sigma2**2) * phi(mu, sigma2) + return first + second + + +def d_h_x2(a, x, y, dirac): + rho = torch.sigmoid(a - torch.log(phi(x, y))) * dirac + rho_prime = rho * (1 - rho) + return -rho_prime * d_varpsi_x1(x, y) / phi(x, y) + + +def d_h_x3(a, x, y, dirac): + rho = torch.sigmoid(a - torch.log(phi(x, y))) * dirac + rho_prime = rho * (1 - rho) + return -rho_prime * d_varpsi_x2(x, y) / phi(x, y) + + +def vec_to_mat(C, p, q): + c = torch.zeros(p, q) + c[torch.tril_indices(p, q, offset=0).tolist()] = C + # c = C.reshape(p,q) + return c + + +def mat_to_vec(matc, p, q): + tril = torch.tril(matc) + # tril = matc.reshape(-1,1).squeeze() + return tril[torch.tril_indices(p, q, offset=0).tolist()] diff --git a/pyPLNmodels/elbos.py b/pyPLNmodels/elbos.py index 6dcda36199eaf5bebcb6eb8f8486f3543a5c3a59..73e7702866ef9b2960412b5cd2a546212c84644e 100644 --- a/pyPLNmodels/elbos.py +++ b/pyPLNmodels/elbos.py @@ -5,63 +5,6 @@ from ._closed_forms import _closed_formula_covariance, _closed_formula_coef from typing import Optional -def elbo_pln( - endog: torch.Tensor, - offsets: torch.Tensor, - exog: Optional[torch.Tensor], - latent_mean: torch.Tensor, - latent_sqrt_var: torch.Tensor, - covariance: torch.Tensor, - coef: torch.Tensor, -) -> torch.Tensor: - """ - Compute the ELBO (Evidence Lower Bound) for the Pln model. - - Parameters: - ---------- - endog : torch.Tensor - Counts with size (n, p). - offsets : torch.Tensor - Offset with size (n, p). - exog : torch.Tensor, optional - Covariates with size (n, d). - latent_mean : torch.Tensor - Variational parameter with size (n, p). - latent_sqrt_var : torch.Tensor - Variational parameter with size (n, p). - covariance : torch.Tensor - Model parameter with size (p, p). - coef : torch.Tensor - Model parameter with size (d, p). - - Returns: - ------- - torch.Tensor - The ELBO (Evidence Lower Bound), of size one. 
- """ - n_samples, dim = endog.shape - s_rond_s = torch.square(latent_sqrt_var) - offsets_plus_m = offsets + latent_mean - if exog is None: - XB = torch.zeros_like(endog) - else: - XB = exog @ coef - m_minus_xb = latent_mean - XB - d_plus_minus_xb2 = ( - torch.diag(torch.sum(s_rond_s, dim=0)) + m_minus_xb.T @ m_minus_xb - ) - elbo = -0.5 * n_samples * torch.logdet(covariance) - elbo += torch.sum( - endog * offsets_plus_m - - 0.5 * torch.exp(offsets_plus_m + s_rond_s) - + 0.5 * torch.log(s_rond_s) - ) - elbo -= 0.5 * torch.trace(torch.inverse(covariance) @ d_plus_minus_xb2) - elbo -= torch.sum(_log_stirling(endog)) - elbo += 0.5 * n_samples * dim - return elbo / n_samples - - def profiled_elbo_pln( endog: torch.Tensor, exog: torch.Tensor, @@ -172,6 +115,79 @@ def elbo_plnpca( ) / n_samples +def log1pexp(x): + # more stable version of log(1 + exp(x)) + return torch.where(x < 50, torch.log1p(torch.exp(x)), x) + + +def elbo_pln( + endog: torch.Tensor, + exog: Optional[torch.Tensor], + offsets: torch.Tensor, + latent_mean: torch.Tensor, + latent_sqrt_var: torch.Tensor, + covariance: torch.Tensor, + coef: torch.Tensor, +) -> torch.Tensor: + """ + Compute the ELBO (Evidence Lower Bound) for the Pln model. + + Parameters: + ---------- + endog : torch.Tensor + Counts with size (n, p). + offsets : torch.Tensor + Offset with size (n, p). + exog : torch.Tensor, optional + Covariates with size (n, d). + latent_mean : torch.Tensor + Variational parameter with size (n, p). + latent_sqrt_var : torch.Tensor + Variational parameter with size (n, p). + covariance : torch.Tensor + Model parameter with size (p, p). + coef : torch.Tensor + Model parameter with size (d, p). + + Returns: + ------- + torch.Tensor + The ELBO (Evidence Lower Bound), of size one. + """ + n_samples, dim = endog.shape + s_rond_s = torch.square(latent_sqrt_var) + offsets_plus_m = offsets + latent_mean + Omega = torch.inverse(covariance) + if exog is None: + XB = torch.zeros_like(endog) + else: + XB = exog @ coef + # print('XB:', XB) + m_minus_xb = latent_mean - XB + m_moins_xb_outer = torch.mm(m_minus_xb.T, m_minus_xb) + A = torch.exp(offsets_plus_m + s_rond_s / 2) + first_a = torch.sum(endog * offsets_plus_m) + sec_a = -torch.sum(A) + third_a = -torch.sum(_log_stirling(endog)) + a = first_a + sec_a + third_a + diag = torch.diag(torch.sum(s_rond_s, dim=0)) + elbo = torch.clone(a) + b = -0.5 * n_samples * torch.logdet(covariance) + torch.sum( + -1 / 2 * Omega * m_moins_xb_outer + ) + elbo += b + d = n_samples * dim / 2 + torch.sum(+0.5 * torch.log(s_rond_s)) + elbo += d + f = -0.5 * torch.trace(torch.inverse(covariance) @ diag) + elbo += f + # print("a pln", a) + # print("b pln", b) + # print("d pln", d) + # print("f pln", f) + return elbo # / n_samples + + +## pb with trunc_log ## should rename some variables so that is is clearer when we see the formula def elbo_zi_pln( endog, @@ -179,13 +195,13 @@ def elbo_zi_pln( offsets, latent_mean, latent_sqrt_var, - pi, - covariance, + latent_prob, + components, coef, - _coef_inflation, + coef_inflation, dirac, ): - """Compute the ELBO (Evidence LOwer Bound) for the Zero Inflated Pln model. + """Compute the ELBO (Evidence LOwer Bound) for the Zero Inflated PLN model. See the doc for more details on the computation. Args: @@ -193,45 +209,66 @@ def elbo_zi_pln( 0: torch.tensor. Offset, size (n,p) exog: torch.tensor. Covariates, size (n,d) latent_mean: torch.tensor. Variational parameter with size (n,p) - latent_sqrt_var: torch.tensor. 
Variational parameter with size (n,p) + latent_var: torch.tensor. Variational parameter with size (n,p) pi: torch.tensor. Variational parameter with size (n,p) covariance: torch.tensor. Model parameter with size (p,p) coef: torch.tensor. Model parameter with size (d,p) - _coef_inflation: torch.tensor. Model parameter with size (d,p) + coef_inflation: torch.tensor. Model parameter with size (d,p) Returns: torch.tensor of size 1 with a gradient. """ - if torch.norm(pi * dirac - pi) > 0.0001: - print("Bug") - return False - n_samples = endog.shape[0] - dim = endog.shape[1] - s_rond_s = torch.square(latent_sqrt_var) - offsets_plus_m = offsets + latent_mean - m_minus_xb = latent_mean - exog @ coef - x_coef_inflation = exog @ _coef_inflation - elbo = torch.sum( - (1 - pi) - * ( - endog @ offsets_plus_m - - torch.exp(offsets_plus_m + s_rond_s / 2) - - _log_stirling(endog), - ) - + pi + covariance = components @ (components.T) + if torch.norm(latent_prob * dirac - latent_prob) > 0.00000001: + raise RuntimeError("latent_prob error") + n_samples, dim = endog.shape + s_rond_s = torch.multiply(latent_sqrt_var, latent_sqrt_var) + o_plus_m = offsets + latent_mean + if exog is None: + XB = torch.zeros_like(endog) + x_coef_inflation = torch.zeros_like(endog) + else: + XB = exog @ coef + x_coef_inflation = exog @ coef_inflation + + m_minus_xb = latent_mean - XB + + A = torch.exp(o_plus_m + s_rond_s / 2) + inside_a = torch.multiply( + 1 - latent_prob, torch.multiply(endog, o_plus_m) - A - _log_stirling(endog) + ) + Omega = torch.inverse(covariance) + m_moins_xb_outer = torch.mm(m_minus_xb.T, m_minus_xb) + un_moins_rho = 1 - latent_prob + un_moins_rho_m_moins_xb = un_moins_rho * m_minus_xb + un_moins_rho_m_moins_xb_outer = un_moins_rho_m_moins_xb.T @ un_moins_rho_m_moins_xb + inside_b = -1 / 2 * Omega * un_moins_rho_m_moins_xb_outer + inside_c = torch.multiply(latent_prob, x_coef_inflation) - torch.log( + 1 + torch.exp(x_coef_inflation) ) - elbo -= torch.sum(pi * _trunc_log(pi) + (1 - pi) * _trunc_log(1 - pi)) - elbo += torch.sum( - pi * x_coef_inflation - torch.log(1 + torch.exp(x_coef_inflation)) + log_diag = torch.log(torch.diag(covariance)) + log_S_term = torch.sum( + torch.multiply(1 - latent_prob, torch.log(torch.abs(latent_sqrt_var))), axis=0 ) + y = torch.sum(latent_prob, axis=0) + covariance_term = 1 / 2 * torch.log(torch.diag(covariance)) * y + inside_d = covariance_term + log_S_term - elbo -= 0.5 * torch.trace( - torch.mm( - torch.inverse(covariance), - torch.diag(torch.sum(s_rond_s, dim=0)) + m_minus_xb.T @ m_minus_xb, - ) + inside_e = -torch.multiply(latent_prob, _trunc_log(latent_prob)) - torch.multiply( + 1 - latent_prob, _trunc_log(1 - latent_prob) + ) + sum_un_moins_rho_s2 = torch.sum(torch.multiply(1 - latent_prob, s_rond_s), axis=0) + diag_sig_sum_rho = torch.multiply( + torch.diag(covariance), torch.sum(latent_prob, axis=0) ) - elbo += 0.5 * n_samples * torch.log(torch.det(covariance)) - elbo += 0.5 * n_samples * dim - elbo += 0.5 * torch.sum(torch.log(s_rond_s)) - return elbo + new = torch.sum(latent_prob * un_moins_rho * (m_minus_xb**2), axis=0) + K = sum_un_moins_rho_s2 + diag_sig_sum_rho + new + inside_f = -1 / 2 * torch.diag(Omega) * K + first = torch.sum(inside_a + inside_c + inside_e) + second = torch.sum(inside_b) + _, logdet = torch.slogdet(components) + second -= n_samples * logdet + third = torch.sum(inside_d + inside_f) + third += n_samples * dim / 2 + res = first + second + third + return res diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py index 
9c2024ee50b0a0706b23f17837bdbf3c552710b6..d5854adfc3614003fb84c8a5ea974d7ffc9d1b38 100644 --- a/pyPLNmodels/models.py +++ b/pyPLNmodels/models.py @@ -2,7 +2,7 @@ import time from abc import ABC, abstractmethod import warnings import os -from typing import Optional, Dict, List, Type, Any, Iterable, Union +from typing import Optional, Dict, List, Type, Any, Iterable, Union, Literal import pandas as pd import torch @@ -18,11 +18,11 @@ from scipy import stats from ._closed_forms import ( _closed_formula_coef, _closed_formula_covariance, - _closed_formula_pi, + _closed_formula_latent_prob, ) from .elbos import elbo_plnpca, elbo_zi_pln, profiled_elbo_pln from ._utils import ( - _PlotArgs, + _CriterionArgs, _format_data, _nice_string_of_dict, _plot_ellipse, @@ -32,6 +32,8 @@ from ._utils import ( _array2tensor, _handle_data, _add_doc, + vec_to_mat, + mat_to_vec, ) from ._initialization import ( @@ -57,7 +59,6 @@ class _model(ABC): Base class for all the Pln models. Should be inherited. """ - _WINDOW: int = 15 _endog: torch.Tensor _exog: torch.Tensor _offsets: torch.Tensor @@ -65,6 +66,7 @@ class _model(ABC): _beginning_time: float _latent_sqrt_var: torch.Tensor _latent_mean: torch.Tensor + _batch_size: int = None def __init__( self, @@ -107,9 +109,10 @@ class _model(ABC): endog, exog, offsets, offsets_formula, take_log_offsets, add_const ) self._fitted = False - self._plotargs = _PlotArgs(self._WINDOW) + self._criterion_args = _CriterionArgs() if dict_initialization is not None: self._set_init_parameters(dict_initialization) + self._dirac = self._endog == 0 @classmethod def from_formula( @@ -160,12 +163,35 @@ class _model(ABC): """ if "coef" not in dict_initialization.keys(): print("No coef is initialized.") - self.coef = None + dict_initialization["coef"] = None + if self._NAME == "Pln": + del dict_initialization["covariance"] + del dict_initialization["coef"] for key, array in dict_initialization.items(): array = _format_data(array) setattr(self, key, array) self._fitted = True + @property + def batch_size(self) -> int: + """ + The batch size of the model. Should not be greater than the number of samples. + """ + if self._batch_size is None: + return self.n_samples + return self._batch_size + + @property + def _current_batch_size(self) -> int: + return self._exog_b.shape[0] + + @batch_size.setter + def batch_size(self, batch_size: int): + """ + Setter for the batch size. Should be an integer not greater than the number of samples. 
+ """ + self._batch_size = self._handle_batch_size(batch_size) + @property def fitted(self) -> bool: """ @@ -210,12 +236,31 @@ class _model(ABC): y = proj_variables[:, 1] sns.scatterplot(x=x, y=y, hue=colors, ax=ax) if show_cov is True: - sk_components = torch.from_numpy(pca.components_) - covariances = self._get_pca_low_dim_covariances(sk_components).detach() + sk_components = torch.from_numpy(pca.components_).to(DEVICE) + covariances = ( + self._get_pca_low_dim_covariances(sk_components).cpu().detach() + ) for i in range(covariances.shape[0]): _plot_ellipse(x[i], y[i], cov=covariances[i], ax=ax) return ax + def _project_parameters(self): + pass + + def _handle_batch_size(self, batch_size): + if batch_size is None: + if hasattr(self, "batch_size"): + batch_size = self.batch_size + else: + batch_size = self.n_samples + if batch_size > self.n_samples: + raise ValueError( + f"batch_size ({batch_size}) can not be greater than the number of samples ({self.n_samples})" + ) + elif isinstance(batch_size, int) is False: + raise ValueError(f"batch_size should be int, got {type(batch_size)}") + return batch_size + @property def nb_iteration_done(self) -> int: """ @@ -226,7 +271,7 @@ class _model(ABC): int The number of iterations done. """ - return len(self._plotargs._elbos_list) + return len(self._criterion_args._elbos_list) * self.nb_batches @property def n_samples(self) -> int: @@ -280,20 +325,6 @@ class _model(ABC): self._coef = None self._coef = torch.randn((self.nb_cov, self.dim), device=DEVICE) - @abstractmethod - def _random_init_model_parameters(self): - """ - Abstract method to randomly initialize model parameters. - """ - pass - - @abstractmethod - def _random_init_latent_parameters(self): - """ - Abstract method to randomly initialize latent parameters. - """ - pass - def _smart_init_latent_parameters(self): """ Initialize latent parameters smartly. @@ -320,32 +351,20 @@ class _model(ABC): def _put_parameters_to_device(self): """ - Move parameters to the device. + Move parameters to the cGPU device if present. """ for parameter in self._list_of_parameters_needing_gradient: parameter.requires_grad_(True) - @property - def _list_of_parameters_needing_gradient(self): - """ - A list containing all the parameters that need to be upgraded via a gradient step. - - Returns - ------- - List[torch.Tensor] - List of parameters needing gradient. - """ - ... - def fit( self, nb_max_iteration: int = 50000, *, lr: float = 0.01, - class_optimizer: torch.optim.Optimizer = torch.optim.Rprop, tol: float = 1e-3, do_smart_init: bool = True, verbose: bool = False, + batch_size=None, ): """ Fit the model. The lower tol, the more accurate the model. @@ -356,50 +375,126 @@ class _model(ABC): The maximum number of iterations. Defaults to 50000. lr : float, optional(keyword-only) The learning rate. Defaults to 0.01. - class_optimizer : torch.optim.Optimizer, optional - The optimizer class. Defaults to torch.optim.Rprop. tol : float, optional(keyword-only) - The tolerance for convergence. Defaults to 1e-3. + The tolerance for convergence. Defaults to 1e-8. do_smart_init : bool, optional(keyword-only) Whether to perform smart initialization. Defaults to True. verbose : bool, optional(keyword-only) Whether to print training progress. Defaults to False. + batch_size: int, optional(keyword-only) + The batch size when optimizing the elbo. If None, + batch gradient descent will be performed (i.e. batch_size = n_samples). + Raises + ------ + ValueError + If the batch_size is greater than the number of samples, or not int. 
""" - self._pring_beginning_message() + self._print_beginning_message() self._beginning_time = time.time() - + self._batch_size = self._handle_batch_size(batch_size) if self._fitted is False: self._init_parameters(do_smart_init) - elif len(self._plotargs.running_times) > 0: - self._beginning_time -= self._plotargs.running_times[-1] + elif len(self._criterion_args.running_times) > 0: + self._beginning_time -= self._criterion_args.running_times[-1] self._put_parameters_to_device() - self.optim = class_optimizer(self._list_of_parameters_needing_gradient, lr=lr) + self._handle_optimizer(lr) stop_condition = False while self.nb_iteration_done < nb_max_iteration and not stop_condition: loss = self._trainstep() - criterion = self._compute_criterion_and_update_plotargs(loss, tol) + criterion = self._update_criterion_args(loss) if abs(criterion) < tol: stop_condition = True - if verbose and self.nb_iteration_done % 50 == 0: + if verbose and self.nb_iteration_done % 50 == 1: self._print_stats() self._print_end_of_fitting_message(stop_condition, tol) self._fitted = True + def _handle_optimizer(self, lr): + if self.batch_size < self.n_samples: + self.optim = torch.optim.Adam( + self._list_of_parameters_needing_gradient, lr=lr + ) + else: + self.optim = torch.optim.Rprop( + self._list_of_parameters_needing_gradient, lr=lr + ) + + def _get_batch(self, shuffle=False): + """Get the batches required to do a minibatch gradient ascent. + + Args: + batch_size: int. The batch size. Should be lower than n. + + Returns: A generator. Will generate n//batch_size + 1 batches of + size batch_size (except the last one since the rest of the + division is not always 0) + """ + indices = np.arange(self.n_samples) + if shuffle: + np.random.shuffle(indices) + for i in range(self._nb_full_batch): + batch = self._return_batch( + indices, i * self._batch_size, (i + 1) * self._batch_size + ) + yield batch + # Last batch + if self._last_batch_size != 0: + yield self._return_batch(indices, -self._last_batch_size, self.n_samples) + + def _return_batch(self, indices, beginning, end): + to_take = torch.tensor(indices[beginning:end]).to(DEVICE) + if self._exog is not None: + exog_b = torch.index_select(self._exog, 0, to_take) + else: + exog_b = None + return ( + torch.index_select(self._endog, 0, to_take), + exog_b, + torch.index_select(self._offsets, 0, to_take), + torch.index_select(self._latent_mean, 0, to_take), + torch.index_select(self._latent_sqrt_var, 0, to_take), + ) + + @property + def _nb_full_batch(self): + return self.n_samples // self.batch_size + + @property + def _last_batch_size(self): + return self.n_samples % self.batch_size + + @property + def nb_batches(self): + return self._nb_full_batch + (self._last_batch_size > 0) + def _trainstep(self): """ - Perform a single training step. + Perform a single pass of the data. Returns ------- torch.Tensor The loss value. 
""" - self.optim.zero_grad() - loss = -self.compute_elbo() - loss.backward() - self.optim.step() - self._update_closed_forms() - return loss + elbo = 0 + for batch in self._get_batch(shuffle=False): + self._extract_batch(batch) + self.optim.zero_grad() + loss = -self._compute_elbo_b() + if torch.sum(torch.isnan(loss)): + raise ValueError("test") + loss.backward() + elbo += loss.item() + self.optim.step() + self._project_parameters() + return elbo / self.nb_batches + + def _extract_batch(self, batch): + self._endog_b = batch[0] + self._exog_b = batch[1] + self._offsets_b = batch[2] + self._latent_mean_b = batch[3] + self._latent_sqrt_var_b = batch[4] def transform(self): """ @@ -439,7 +534,7 @@ class _model(ABC): def sk_PCA(self, n_components=None): """ - Perform PCA on the latent variables. + Perform the scikit-learn PCA on the latent variables. Parameters ---------- @@ -461,12 +556,13 @@ class _model(ABC): raise ValueError( f"You ask more components ({n_components}) than variables ({self.dim})" ) + latent_variables = self.transform() pca = PCA(n_components=n_components) - pca.fit(self.latent_variables.cpu()) + pca.fit(latent_variables.cpu()) return pca @property - def latent_var(self) -> torch.Tensor: + def latent_variance(self) -> torch.Tensor: """ Property representing the latent variance. @@ -503,9 +599,9 @@ class _model(ABC): f"You ask more components ({n_components}) than variables ({self.dim})" ) pca = self.sk_PCA(n_components=n_components) - proj_variables = pca.transform(self.latent_variables) + latent_variables = self.transform() + proj_variables = pca.transform(latent_variables) components = torch.from_numpy(pca.components_) - labels = { str(i): f"PC{i+1}: {np.round(pca.explained_variance_ratio_*100, 1)[i]}%" for i in range(n_components) @@ -563,7 +659,7 @@ class _model(ABC): n_components = 2 pca = self.sk_PCA(n_components=n_components) - variables = self.latent_variables + variables = self.transform() proj_variables = pca.transform(variables) ## the package is not correctly printing the variance ratio figure, correlation_matrix = plot_pca_correlation_graph( @@ -577,12 +673,16 @@ class _model(ABC): plt.show() @property - @abstractmethod - def latent_variables(self): + def _latent_var(self) -> torch.Tensor: """ - Abstract property representing the latent variables. + Property representing the latent variance. + + Returns + ------- + torch.Tensor + The latent variance tensor. """ - pass + return self._latent_sqrt_var**2 def _print_end_of_fitting_message(self, stop_condition: bool, tol: float): """ @@ -598,14 +698,14 @@ class _model(ABC): if stop_condition is True: print( f"Tolerance {tol} reached " - f"in {self._plotargs.iteration_number} iterations" + f"in {self._criterion_args.iteration_number} iterations" ) else: print( "Maximum number of iterations reached : ", - self._plotargs.iteration_number, + self._criterion_args.iteration_number, "last criterion = ", - np.round(self._plotargs.criterions[-1], 8), + np.round(self._criterion_args.criterion_list[-1], 8), ) def _print_stats(self): @@ -613,11 +713,11 @@ class _model(ABC): Print the training statistics. 
""" print("-------UPDATE-------") - print("Iteration number: ", self._plotargs.iteration_number) - print("Criterion: ", np.round(self._plotargs.criterions[-1], 8)) - print("ELBO:", np.round(self._plotargs._elbos_list[-1], 6)) + print("Iteration number: ", self._criterion_args.iteration_number) + print("Criterion: ", np.round(self._criterion_args.criterion_list[-1], 8)) + print("ELBO:", np.round(self._criterion_args._elbos_list[-1], 6)) - def _compute_criterion_and_update_plotargs(self, loss, tol): + def _update_criterion_args(self, loss): """ Compute the convergence criterion and update the plot arguments. @@ -625,38 +725,15 @@ class _model(ABC): ---------- loss : torch.Tensor The loss value. - tol : float - The tolerance for convergence. Returns ------- float The computed criterion. """ - self._plotargs._elbos_list.append(-loss.item()) - self._plotargs.running_times.append(time.time() - self._beginning_time) - if self._plotargs.iteration_number > self._WINDOW: - criterion = abs( - self._plotargs._elbos_list[-1] - - self._plotargs._elbos_list[-1 - self._WINDOW] - ) - self._plotargs.criterions.append(criterion) - return criterion - return tol - - def _update_closed_forms(self): - """ - Update closed-form expressions. - """ - pass - - @abstractmethod - def compute_elbo(self): - """ - Compute the Evidence Lower BOund (ELBO) that will be maximized - by pytorch. - """ - pass + current_running_time = time.time() - self._beginning_time + self._criterion_args.update_criterion(-loss, current_running_time) + return self._criterion_args.criterion def display_covariance(self, ax=None, savefig=False, name_file=""): """ @@ -743,8 +820,8 @@ class _model(ABC): if axes is None: _, axes = plt.subplots(1, nb_axes, figsize=(23, 5)) if self._fitted is True: - self._plotargs._show_loss(ax=axes[2]) - self._plotargs._show_stopping_criterion(ax=axes[1]) + self._criterion_args._show_loss(ax=axes[2]) + self._criterion_args._show_stopping_criterion(ax=axes[1]) self.display_covariance(ax=axes[0]) else: self.display_covariance(ax=axes) @@ -755,7 +832,7 @@ class _model(ABC): """ Property representing the list of ELBO values. """ - return self._plotargs._elbos_list + return self._criterion_args._elbos_list @property def loglike(self): @@ -769,8 +846,8 @@ class _model(ABC): """ if len(self._elbos_list) == 0: t0 = time.time() - self._plotargs._elbos_list.append(self.compute_elbo().item()) - self._plotargs.running_times.append(time.time() - t0) + self._criterion_args._elbos_list.append(self.compute_elbo().item()) + self._criterion_args.running_times.append(time.time() - t0) return self.n_samples * self._elbos_list[-1] @property @@ -797,33 +874,6 @@ class _model(ABC): """ return -self.loglike + self.number_of_parameters - @property - def latent_parameters(self): - """ - Property representing the latent parameters. - - Returns - ------- - dict - The dictionary of latent parameters. - """ - return { - "latent_sqrt_var": self.latent_sqrt_var, - "latent_mean": self.latent_mean, - } - - @property - def model_parameters(self): - """ - Property representing the model parameters. - - Returns - ------- - dict - The dictionary of model parameters. - """ - return {"coef": self.coef, "covariance": self.covariance} - @property def dict_data(self): """ @@ -920,31 +970,7 @@ class _model(ABC): raise ValueError( f"Wrong shape. 
Expected {self.n_samples, self.dim}, got {latent_mean.shape}" ) - self._latent_mean = latent_mean - - @latent_sqrt_var.setter - @_array2tensor - def latent_sqrt_var( - self, latent_sqrt_var: Union[torch.Tensor, np.ndarray, pd.DataFrame] - ): - """ - Setter for the latent variance property. - - Parameters - ---------- - latent_sqrt_var : Union[torch.Tensor, np.ndarray, pd.DataFrame] - The latent variance. - - Raises - ------ - ValueError - If the shape of the latent variance is incorrect. - """ - if latent_sqrt_var.shape != (self.n_samples, self.dim): - raise ValueError( - f"Wrong shape. Expected {self.n_samples, self.dim}, got {latent_sqrt_var.shape}" - ) - self._latent_sqrt_var = latent_sqrt_var + self._latent_mean = latent_mean.to(DEVICE) def _cpu_attribute_or_none(self, attribute_name): """ @@ -981,6 +1007,13 @@ class _model(ABC): os.makedirs(path, exist_ok=True) for key, value in self._dict_parameters.items(): filename = f"{path}/{key}.csv" + if key == "latent_prob": + if torch.max(value) > 1 or torch.min(value) < 0: + if ( + torch.norm(self.dirac * self.latent_prob - self.latent_prob) + > 0.0001 + ): + raise Exception("Error is here") if isinstance(value, torch.Tensor): pd.DataFrame(np.array(value.cpu().detach())).to_csv( filename, header=None, index=None @@ -1233,18 +1266,6 @@ class _model(ABC): """ return f"{self._NAME}_nbcov_{self.nb_cov}_dim_{self.dim}" - @property - def _path_to_directory(self): - """ - Property representing the path to the directory. - - Returns - ------- - str - The path to the directory. - """ - return "" - def plot_expected_vs_true(self, ax=None, colors=None): """ Plot the predicted value of the endog against the endog. @@ -1267,7 +1288,7 @@ class _model(ABC): raise RuntimeError("Please fit the model before.") if ax is None: ax = plt.gca() - predictions = self._endog_predictions().ravel().detach() + predictions = self._endog_predictions().ravel().cpu().detach() if colors is not None: colors = np.repeat(np.array(colors), repeats=self.dim).ravel() sns.scatterplot(x=self.endog.ravel(), y=predictions, hue=colors, ax=ax) @@ -1281,8 +1302,119 @@ class _model(ABC): ax.legend() return ax + def _print_beginning_message(self): + """ + Method for printing the beginning message. + """ + print(f"Fitting a {self._NAME} model with {self._description}") + + @property + @abstractmethod + def latent_variables(self) -> torch.Tensor: + """ + Property representing the latent variables. + + Returns + ------- + torch.Tensor + The latent variables of size (n_samples, dim). + """ + + @abstractmethod + def compute_elbo(self): + """ + Compute the Evidence Lower BOund (ELBO) that will be maximized + by pytorch. + + Returns + ------- + torch.Tensor + The computed ELBO. + """ + + @abstractmethod + def _compute_elbo_b(self): + """ + Compute the Evidence Lower BOund (ELBO) for the current mini-batch. + Returns + ------- + torch.Tensor + The computed ELBO on the current batch. + """ + + @abstractmethod + def _random_init_model_parameters(self): + """ + Abstract method to randomly initialize model parameters. + """ + + @abstractmethod + def _random_init_latent_parameters(self): + """ + Abstract method to randomly initialize latent parameters. + """ + + @abstractmethod + def _smart_init_latent_parameters(self): + """ + Method for smartly initializing the latent parameters. + """ + + @abstractmethod + def _smart_init_model_parameters(self): + """ + Method for smartly initializing the model parameters. 
+ """ + + @property + @abstractmethod + def _list_of_parameters_needing_gradient(self): + """ + A list containing all the parameters that need to be upgraded via a gradient step. + + Returns + ------- + List[torch.Tensor] + List of parameters needing gradient. + """ + + @property + @abstractmethod + def _description(self): + pass + + @property + @abstractmethod + def number_of_parameters(self): + """ + Number of parameters of the model. + """ + + @property + @abstractmethod + def model_parameters(self) -> Dict[str, torch.Tensor]: + """ + Property representing the model parameters. + + Returns + ------- + dict + The dictionary of model parameters. + """ + + @property + @abstractmethod + def latent_parameters(self) -> Dict[str, torch.Tensor]: + """ + Property representing the latent parameters. + + Returns + ------- + dict + The dictionary of latent parameters. + """ + -# need to do a good init for M and S class Pln(_model): """ Pln class. @@ -1371,15 +1503,12 @@ class Pln(_model): dict_initialization: Optional[Dict[str, torch.Tensor]] = None, take_log_offsets: bool = False, ): - endog, exog, offsets = _extract_data_from_formula(formula, data) - return cls( - endog, - exog=exog, - offsets=offsets, + return super().from_formula( + formula=formula, + data=data, offsets_formula=offsets_formula, dict_initialization=dict_initialization, take_log_offsets=take_log_offsets, - add_const=False, ) @_add_doc( @@ -1397,18 +1526,18 @@ class Pln(_model): nb_max_iteration: int = 50000, *, lr: float = 0.01, - class_optimizer: torch.optim.Optimizer = torch.optim.Rprop, tol: float = 1e-3, do_smart_init: bool = True, verbose: bool = False, + batch_size: int = None, ): super().fit( nb_max_iteration, lr=lr, - class_optimizer=class_optimizer, tol=tol, do_smart_init=do_smart_init, verbose=verbose, + batch_size=batch_size, ) @_add_doc( @@ -1546,42 +1675,18 @@ class Pln(_model): ---------- coef : Union[torch.Tensor, np.ndarray, pd.DataFrame] The regression coefficients of the gaussian latent variables. + Raises + ------ + AttributeError since you can not set the coef in the Pln model. """ + msg = "You can not set the coef in the Pln model." + warnings.warn(msg) def _endog_predictions(self): return torch.exp( self._offsets + self._latent_mean + 1 / 2 * self._latent_sqrt_var**2 ) - def _smart_init_latent_parameters(self): - """ - Method for smartly initializing the latent parameters. - """ - self._random_init_latent_parameters() - - def _random_init_latent_parameters(self): - """ - Method for randomly initializing the latent parameters. - """ - if not hasattr(self, "_latent_sqrt_var"): - self._latent_sqrt_var = ( - 1 / 2 * torch.ones((self.n_samples, self.dim)).to(DEVICE) - ) - if not hasattr(self, "_latent_mean"): - self._latent_mean = torch.ones((self.n_samples, self.dim)).to(DEVICE) - - @property - def _list_of_parameters_needing_gradient(self): - """ - Property representing the list of parameters needing gradient. - - Returns - ------- - list - The list of parameters needing gradient. - """ - return [self._latent_mean, self._latent_sqrt_var] - def _get_max_components(self): """ Method for getting the maximum number of components. @@ -1593,44 +1698,6 @@ class Pln(_model): """ return self.dim - def compute_elbo(self): - """ - Method for computing the evidence lower bound (ELBO). - - Returns - ------- - torch.Tensor - The computed ELBO. 
- Examples - -------- - >>> from pyPLNmodels import Pln, get_real_count_data - >>> endog, labels = get_real_count_data(return_labels = True) - >>> pln = Pln(endog,add_const = True) - >>> pln.fit() - >>> elbo = pln.compute_elbo() - >>> print("elbo", elbo) - >>> print("loglike/n", pln.loglike/pln.n_samples) - """ - return profiled_elbo_pln( - self._endog, - self._exog, - self._offsets, - self._latent_mean, - self._latent_sqrt_var, - ) - - def _smart_init_model_parameters(self): - """ - Method for smartly initializing the model parameters. - """ - # no model parameters since we are doing a profiled ELBO - - def _random_init_model_parameters(self): - """ - Method for randomly initializing the model parameters. - """ - # no model parameters since we are doing a profiled ELBO - @property def _coef(self): """ @@ -1668,25 +1735,29 @@ class Pln(_model): covariances = components_var @ (sk_components.T.unsqueeze(0)) return covariances - def _pring_beginning_message(self): - """ - Method for printing the beginning message. + @_model.latent_sqrt_var.setter + @_array2tensor + def latent_sqrt_var( + self, latent_sqrt_var: Union[torch.Tensor, np.ndarray, pd.DataFrame] + ): """ - print(f"Fitting a Pln model with {self._description}") + Setter for the latent variance property. - @property - @_add_doc( - _model, - example=""" - >>> from pyPLNmodels import Pln, get_real_count_data - >>> endog, labels = get_real_count_data(return_labels = True) - >>> pln = Pln(endog,add_const = True) - >>> pln.fit() - >>> print(pln.latent_variables.shape) - """, - ) - def latent_variables(self): - return self.latent_mean.detach() + Parameters + ---------- + latent_sqrt_var : Union[torch.Tensor, np.ndarray, pd.DataFrame] + The latent variance. + + Raises + ------ + ValueError + If the shape of the latent variance is incorrect. + """ + if latent_sqrt_var.shape != (self.n_samples, self.dim): + raise ValueError( + f"Wrong shape. Expected {self.n_samples, self.dim}, got {latent_sqrt_var.shape}" + ) + self._latent_sqrt_var = latent_sqrt_var @property def number_of_parameters(self): @@ -1734,7 +1805,119 @@ class Pln(_model): covariance : torch.Tensor The covariance matrix. 
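`Pln.from_formula` now simply forwards to the parent implementation; a usage sketch of the formula interface, assuming the patsy-style `"endog ~ 1"` formula and data-dict layout expected by `_extract_data_from_formula`:

```py
from pyPLNmodels import Pln, get_real_count_data

endog = get_real_count_data()
data = {"endog": endog}
pln = Pln.from_formula("endog ~ 1", data=data)  # intercept-only model
pln.fit()
print(pln.latent_variables.shape)
```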
""" + warnings.warn("You can not set the covariance for the Pln model.") + + def _random_init_latent_sqrt_var(self): + if not hasattr(self, "_latent_sqrt_var"): + self._latent_sqrt_var = ( + 1 / 2 * torch.ones((self.n_samples, self.dim)).to(DEVICE) + ) + + @_add_doc(_model) + def _smart_init_model_parameters(self): + pass + # no model parameters since we are doing a profiled ELBO + + @_add_doc(_model) + def _random_init_model_parameters(self): + pass + # no model parameters since we are doing a profiled ELBO + + @property + @_add_doc(_model) + def _list_of_parameters_needing_gradient(self): + return [self._latent_mean, self._latent_sqrt_var] + + @property + @_add_doc(_model) + def model_parameters(self) -> Dict[str, torch.Tensor]: + return {"coef": self.coef, "covariance": self.covariance} + + @property + @_add_doc(_model) + def latent_parameters(self): + return { + "latent_sqrt_var": self.latent_sqrt_var, + "latent_mean": self.latent_mean, + } + + def _random_init_latent_sqrt_var(self): + if not hasattr(self, "_latent_sqrt_var"): + self._latent_sqrt_var = ( + 1 / 2 * torch.ones((self.n_samples, self.dim)).to(DEVICE) + ) + + @property + @_add_doc( + _model, + example=""" + >>> from pyPLNmodels import Pln, get_real_count_data + >>> endog, labels = get_real_count_data(return_labels = True) + >>> pln = Pln(endog,add_const = True) + >>> pln.fit() + >>> print(pln.latent_variables.shape) + """, + ) + def latent_variables(self): + return self.latent_mean.detach() + + @_add_doc( + _model, + example=""" + >>> from pyPLNmodels import Pln, get_real_count_data + >>> endog, labels = get_real_count_data(return_labels = True) + >>> pln = Pln(endog,add_const = True) + >>> pln.fit() + >>> elbo = pln.compute_elbo() + >>> print("elbo", elbo) + >>> print("loglike/n", pln.loglike/pln.n_samples) + """, + ) + def compute_elbo(self): + return profiled_elbo_pln( + self._endog, + self._exog, + self._offsets, + self._latent_mean, + self._latent_sqrt_var, + ) + + @_add_doc(_model) + def _compute_elbo_b(self): + return profiled_elbo_pln( + self._endog_b, + self._exog_b, + self._offsets_b, + self._latent_mean_b, + self._latent_sqrt_var_b, + ) + + @_add_doc(_model) + def _smart_init_model_parameters(self): pass + # no model parameters since we are doing a profiled ELBO + + @_add_doc(_model) + def _random_init_model_parameters(self): + pass + # no model parameters since we are doing a profiled ELBO + + @_add_doc(_model) + def _smart_init_latent_parameters(self): + self._random_init_latent_sqrt_var() + if not hasattr(self, "_latent_mean"): + self._latent_mean = torch.log(self._endog + (self._endog == 0)) + + @_add_doc(_model) + def _random_init_latent_parameters(self): + self._random_init_latent_sqrt_var() + if not hasattr(self, "_latent_mean"): + self._latent_mean = torch.ones((self.n_samples, self.dim)).to(DEVICE) + + @property + @_add_doc(_model) + def _list_of_parameters_needing_gradient(self): + return [self._latent_mean, self._latent_sqrt_var] class PlnPCAcollection: @@ -1801,6 +1984,9 @@ class PlnPCAcollection: Whether to take the logarithm of offsets, by default False. add_const: bool, optional(keyword-only) Whether to add a column of one in the exog. Defaults to True. + batch_size: int, optional(keyword-only) + The batch size when optimizing the elbo. If None, + batch gradient descent will be performed (i.e. batch_size = n_samples). 
Returns ------- PlnPCAcollection @@ -1819,7 +2005,7 @@ class PlnPCAcollection: endog, exog, offsets, offsets_formula, take_log_offsets, add_const ) self._fitted = False - self._init_models(ranks, dict_of_dict_initialization) + self._init_models(ranks, dict_of_dict_initialization, add_const=add_const) @classmethod def from_formula( @@ -1851,6 +2037,7 @@ class PlnPCAcollection: The dictionary of initialization, by default None. take_log_offsets : bool, optional(keyword-only) Whether to take the logarithm of offsets, by default False. + Returns ------- PlnPCAcollection @@ -1890,6 +2077,18 @@ class PlnPCAcollection: """ return self[self.ranks[0]].exog + @property + def batch_size(self) -> torch.Tensor: + """ + Property representing the batch_size. + + Returns + ------- + torch.Tensor + The batch_size. + """ + return self[self.ranks[0]].batch_size + @property def endog(self) -> torch.Tensor: """ @@ -1964,6 +2163,19 @@ class PlnPCAcollection: for model in self.values(): model.endog = endog + @batch_size.setter + def batch_size(self, batch_size: int): + """ + Setter for the batch_size property. + + Parameters + ---------- + batch_size : int + The batch size. + """ + for model in self.values(): + model.batch_size = batch_size + @coef.setter @_array2tensor def coef(self, coef: Union[torch.Tensor, np.ndarray, pd.DataFrame]): @@ -2019,7 +2231,10 @@ class PlnPCAcollection: model.offsets = offsets def _init_models( - self, ranks: Iterable[int], dict_of_dict_initialization: Optional[dict] + self, + ranks: Iterable[int], + dict_of_dict_initialization: Optional[dict], + add_const: bool, ): """ Method for initializing the models. @@ -2043,6 +2258,7 @@ class PlnPCAcollection: offsets=self._offsets, rank=rank, dict_initialization=dict_initialization, + add_const=add_const, ) else: raise TypeError( @@ -2087,7 +2303,7 @@ class PlnPCAcollection: """ return [model.rank for model in self.values()] - def _pring_beginning_message(self) -> str: + def _print_beginning_message(self) -> str: """ Method for printing the beginning message. @@ -2127,10 +2343,10 @@ class PlnPCAcollection: nb_max_iteration: int = 50000, *, lr: float = 0.01, - class_optimizer: Type[torch.optim.Optimizer] = torch.optim.Rprop, tol: float = 1e-3, do_smart_init: bool = True, verbose: bool = False, + batch_size: int = None, ): """ Fit each model in the PlnPCAcollection. @@ -2141,25 +2357,30 @@ class PlnPCAcollection: The maximum number of iterations, by default 50000. lr : float, optional(keyword-only) The learning rate, by default 0.01. - class_optimizer : Type[torch.optim.Optimizer], optional(keyword-only) - The optimizer class, by default torch.optim.Rprop. tol : float, optional(keyword-only) - The tolerance, by default 1e-3. + The tolerance, by default 1e-8. do_smart_init : bool, optional(keyword-only) Whether to do smart initialization, by default True. verbose : bool, optional(keyword-only) Whether to print verbose output, by default False. + batch_size: int, optional(keyword-only) + The batch size when optimizing the elbo. If None, + batch gradient descent will be performed (i.e. batch_size = n_samples). + Raises + ------ + ValueError + If the batch_size is greater than the number of samples, or not int. 
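Examples
--------
A minimal mini-batch sketch; the ranks and batch size below are illustrative
choices, not recommendations:

>>> from pyPLNmodels import PlnPCAcollection, get_real_count_data
>>> endog = get_real_count_data()
>>> pcas = PlnPCAcollection(endog, ranks=[3, 5], add_const=True)
>>> pcas.fit(batch_size=20)
>>> print(pcas.BIC)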
""" - self._pring_beginning_message() + self._print_beginning_message() for i in range(len(self.values())): model = self[self.ranks[i]] model.fit( nb_max_iteration, lr=lr, - class_optimizer=class_optimizer, tol=tol, do_smart_init=do_smart_init, verbose=verbose, + batch_size=batch_size, ) if i < len(self.values()) - 1: next_model = self[self.ranks[i + 1]] @@ -2360,23 +2581,26 @@ class PlnPCAcollection: bic = self.BIC aic = self.AIC loglikes = self.loglikes - bic_color = "blue" - aic_color = "red" - loglikes_color = "orange" - plt.scatter(bic.keys(), bic.values(), label="BIC criterion", c=bic_color) - plt.plot(bic.keys(), bic.values(), c=bic_color) - plt.axvline(self.best_BIC_model_rank, c=bic_color, linestyle="dotted") - plt.scatter(aic.keys(), aic.values(), label="AIC criterion", c=aic_color) - plt.axvline(self.best_AIC_model_rank, c=aic_color, linestyle="dotted") - plt.plot(aic.keys(), aic.values(), c=aic_color) - plt.xticks(list(aic.keys())) - plt.scatter( - loglikes.keys(), - -np.array(list(loglikes.values())), - label="Negative log likelihood", - c=loglikes_color, - ) - plt.plot(loglikes.keys(), -np.array(list(loglikes.values())), c=loglikes_color) + colors = {"BIC": "blue", "AIC": "red", "Negative log likelihood": "orange"} + for criterion, values in zip( + ["BIC", "AIC", "Negative log likelihood"], [bic, aic, loglikes] + ): + plt.scatter( + values.keys(), + values.values(), + label=f"{criterion} criterion", + c=colors[criterion], + ) + plt.plot(values.keys(), values.values(), c=colors[criterion]) + if criterion == "BIC": + plt.axvline( + self.best_BIC_model_rank, c=colors[criterion], linestyle="dotted" + ) + elif criterion == "AIC": + plt.axvline( + self.best_AIC_model_rank, c=colors[criterion], linestyle="dotted" + ) + plt.xticks(list(values.keys())) plt.legend() plt.show() @@ -2522,7 +2746,7 @@ class PlnPCAcollection: return ".BIC, .AIC, .loglikes" -# Here, setting the value for each key in _dict_parameters +# Here, setting the value for each key _dict_parameters class PlnPCA(_model): """ PlnPCA object where the covariance has low rank. @@ -2574,7 +2798,7 @@ class PlnPCA(_model): ) def __init__( self, - endog: Optional[Union[torch.Tensor, np.ndarray, pd.DataFrame]], + endog: Union[torch.Tensor, np.ndarray, pd.DataFrame], *, exog: Optional[Union[torch.Tensor, np.ndarray, pd.DataFrame]] = None, offsets: Optional[Union[torch.Tensor, np.ndarray, pd.DataFrame]] = None, @@ -2651,18 +2875,18 @@ class PlnPCA(_model): nb_max_iteration: int = 50000, *, lr: float = 0.01, - class_optimizer: torch.optim.Optimizer = torch.optim.Rprop, tol: float = 1e-3, do_smart_init: bool = True, verbose: bool = False, + batch_size=None, ): super().fit( nb_max_iteration, lr=lr, - class_optimizer=class_optimizer, tol=tol, do_smart_init=do_smart_init, verbose=verbose, + batch_size=batch_size, ) @_add_doc( @@ -2748,19 +2972,6 @@ class PlnPCA(_model): variables_names=variables_names, indices_of_variables=indices_of_variables ) - def _check_if_rank_is_too_high(self): - """ - Check if the rank is too high and issue a warning if necessary. - """ - if self.dim < self.rank: - warning_string = ( - f"\nThe requested rank of approximation {self.rank} " - f"is greater than the number of variables {self.dim}. 
" - f"Setting rank to {self.dim}" - ) - warnings.warn(warning_string) - self._rank = self.dim - @property @_add_doc( _model, @@ -2776,42 +2987,21 @@ class PlnPCA(_model): def latent_mean(self) -> torch.Tensor: return self._cpu_attribute_or_none("_latent_mean") - @property - def latent_sqrt_var(self) -> torch.Tensor: - """ - Property representing the unsigned square root of the latent variance. - - Returns - ------- - torch.Tensor - The latent variance tensor. - """ - return self._cpu_attribute_or_none("_latent_sqrt_var") - - @property - def _latent_var(self) -> torch.Tensor: - """ - Property representing the latent variance. - - Returns - ------- - torch.Tensor - The latent variance tensor. - """ - return self._latent_sqrt_var**2 - def _endog_predictions(self): covariance_a_posteriori = torch.sum( (self._components**2).unsqueeze(0) - * (self.latent_sqrt_var**2).unsqueeze(1), + * (self._latent_sqrt_var**2).unsqueeze(1), axis=2, ) if self.exog is not None: - XB = self.exog @ self.coef + XB = self._exog @ self._coef else: XB = 0 return torch.exp( - self._offsets + XB + self.latent_variables + 1 / 2 * covariance_a_posteriori + self._offsets + + XB + + self.latent_variables.to(DEVICE) + + 1 / 2 * covariance_a_posteriori ) @latent_mean.setter @@ -2831,7 +3021,7 @@ class PlnPCA(_model): ) self._latent_mean = latent_mean - @latent_sqrt_var.setter + @_model.latent_sqrt_var.setter @_array2tensor def latent_sqrt_var(self, latent_sqrt_var: torch.Tensor): """ @@ -2912,129 +3102,29 @@ class PlnPCA(_model): """ return self._rank - def _pring_beginning_message(self): + @property + def number_of_parameters(self) -> int: """ - Print the beginning message when fitted. + Property representing the number of parameters. + + Returns + ------- + int + The number of parameters. """ - print("-" * NB_CHARACTERS_FOR_NICE_PLOT) - print(f"Fitting a PlnPCAcollection model with {self._rank} components") + return self.dim * (self.nb_cov + self._rank) - self._rank * (self._rank - 1) / 2 @property - def model_parameters(self) -> Dict[str, torch.Tensor]: + def _additional_properties_string(self) -> str: """ - Property representing the model parameters. + Property representing the additional properties string. Returns ------- - Dict[str, torch.Tensor] - The model parameters. + str + The additional properties string. """ - return {"coef": self.coef, "components": self.components} - - def _smart_init_model_parameters(self): - """ - Initialize the model parameters smartly. - """ - if not hasattr(self, "_coef"): - super()._smart_init_coef() - if not hasattr(self, "_components"): - self._components = _init_components( - self._endog, self._exog, self._coef, self._rank - ) - - def _random_init_model_parameters(self): - """ - Randomly initialize the model parameters. - """ - super()._random_init_coef() - self._components = torch.randn((self.dim, self._rank)).to(DEVICE) - - def _random_init_latent_parameters(self): - """ - Randomly initialize the latent parameters. - """ - self._latent_sqrt_var = ( - 1 / 2 * torch.ones((self.n_samples, self._rank)).to(DEVICE) - ) - self._latent_mean = torch.ones((self.n_samples, self._rank)).to(DEVICE) - - def _smart_init_latent_parameters(self): - """ - Initialize the latent parameters smartly. 
- """ - if not hasattr(self, "_latent_mean"): - self._latent_mean = ( - _init_latent_mean( - self._endog, - self._exog, - self._offsets, - self._coef, - self._components, - ) - .to(DEVICE) - .detach() - ) - if not hasattr(self, "_latent_sqrt_var"): - self._latent_sqrt_var = ( - 1 / 2 * torch.ones((self.n_samples, self._rank)).to(DEVICE) - ) - - @property - def _list_of_parameters_needing_gradient(self): - """ - Property representing the list of parameters needing gradient. - - Returns - ------- - List[torch.Tensor] - The list of parameters needing gradient. - """ - if self._coef is None: - return [self._components, self._latent_mean, self._latent_sqrt_var] - return [self._components, self._coef, self._latent_mean, self._latent_sqrt_var] - - def compute_elbo(self) -> torch.Tensor: - """ - Compute the evidence lower bound (ELBO). - - Returns - ------- - torch.Tensor - The ELBO value. - """ - return elbo_plnpca( - self._endog, - self._exog, - self._offsets, - self._latent_mean, - self._latent_sqrt_var, - self._components, - self._coef, - ) - - @property - def number_of_parameters(self) -> int: - """ - Property representing the number of parameters. - - Returns - ------- - int - The number of parameters. - """ - return self.dim * (self.nb_cov + self._rank) - self._rank * (self._rank - 1) / 2 - - @property - def _additional_properties_string(self) -> str: - """ - Property representing the additional properties string. - - Returns - ------- - str - The additional properties string. - """ - return ".projected_latent_variables" + return ".projected_latent_variables" @property def _additional_methods_string(self) -> str: @@ -3052,7 +3142,7 @@ class PlnPCA(_model): @property def covariance(self) -> torch.Tensor: """ - Property representing the covariance a posteriori of the latent variables. + Property representing the covariance of the latent variables. Returns ------- @@ -3094,18 +3184,6 @@ class PlnPCA(_model): """ return f" {self.rank} principal component." - @property - def latent_variables(self) -> torch.Tensor: - """ - Property representing the latent variables. - - Returns - ------- - torch.Tensor - The latent variables of size (n_samples, dim). - """ - return torch.matmul(self._latent_mean, self._components.T).detach() - @property def projected_latent_variables(self) -> torch.Tensor: """ @@ -3123,7 +3201,7 @@ class PlnPCA(_model): """ Orthogonal components of the model. """ - return torch.linalg.qr(self._components, "reduced")[0] + return torch.linalg.qr(self._components, "reduced")[0].cpu() @property def components(self) -> torch.Tensor: @@ -3165,7 +3243,7 @@ class PlnPCA(_model): Parameters ---------- project : bool, optional - Whether to project the latent variables, by default True. + Whether to project the latent variables, by default False. 
""", returns=""" torch.Tensor @@ -3182,81 +3260,996 @@ class PlnPCA(_model): >>> print(transformed_endog_high_dim.shape) """, ) - def transform(self, project: bool = True) -> torch.Tensor: + def transform(self, project: bool = False) -> torch.Tensor: if project is True: return self.projected_latent_variables return self.latent_variables + @property + @_add_doc( + _model, + example=""" + >>> from pyPLNmodels import PlnPCA, get_real_count_data + >>> endog = get_real_count_data(return_labels=False) + >>> pca = PlnPCA(endog,add_const = True) + >>> pca.fit() + >>> print(pca.latent_variables.shape) + """, + ) + def latent_variables(self) -> torch.Tensor: + return torch.matmul(self.latent_mean, self.components.T) -class ZIPln(Pln): - _NAME = "ZIPln" - - _pi: torch.Tensor - _coef_inflation: torch.Tensor - _dirac: torch.Tensor + @_add_doc( + _model, + example=""" + >>> from pyPLNmodels import PlnPCA, get_real_count_data + >>> endog = get_real_count_data(return_labels = False) + >>> pca = PlnPCA(endog,add_const = True) + >>> pca.fit() + >>> elbo = pca.compute_elbo() + >>> print("elbo", elbo) + >>> print("loglike/n", pca.loglike/pca.n_samples) + """, + ) + def compute_elbo(self) -> torch.Tensor: + return elbo_plnpca( + self._endog, + self._exog, + self._offsets, + self._latent_mean, + self._latent_sqrt_var, + self._components, + self._coef, + ) - @property - def _description(self): - return "with full covariance model and zero-inflation." + @_add_doc(_model) + def _compute_elbo_b(self) -> torch.Tensor: + return elbo_plnpca( + self._endog_b, + self._exog_b, + self._offsets_b, + self._latent_mean_b, + self._latent_sqrt_var_b, + self._components, + self._coef, + ) + @_add_doc(_model) def _random_init_model_parameters(self): - super()._random_init_model_parameters() - self._coef_inflation = torch.randn(self.nb_cov, self.dim) - self._covariance = torch.diag(torch.ones(self.dim)).to(DEVICE) + super()._random_init_coef() + self._components = torch.randn((self.dim, self._rank)).to(DEVICE) - # should change the good initialization, especially for _coef_inflation + @_add_doc(_model) def _smart_init_model_parameters(self): - super()._smart_init_model_parameters() - if not hasattr(self, "_covariance"): - self._covariance = _init_covariance(self._endog, self._exog, self._coef) - if not hasattr(self, "_coef_inflation"): - self._coef_inflation = torch.randn(self.nb_cov, self.dim) + if not hasattr(self, "_coef"): + super()._smart_init_coef() + if not hasattr(self, "_components"): + self._components = _init_components(self._endog, self._rank) + @_add_doc(_model) def _random_init_latent_parameters(self): - self._dirac = self._endog == 0 - self._latent_mean = torch.randn(self.n_samples, self.dim) - self._latent_sqrt_var = torch.randn(self.n_samples, self.dim) - self._pi = ( - torch.empty(self.n_samples, self.dim).uniform_(0, 1).to(DEVICE) - * self._dirac + """ + Randomly initialize the latent parameters. 
+ """ + self._latent_sqrt_var = ( + 1 / 2 * torch.ones((self.n_samples, self._rank)).to(DEVICE) ) + self._latent_mean = torch.ones((self.n_samples, self._rank)).to(DEVICE) - def compute_elbo(self): - return elbo_zi_pln( - self._endog, - self._exog, - self._offsets, - self._latent_mean, - self._latent_sqrt_var, - self._pi, - self._covariance, - self._coef, - self._coef_inflation, - self._dirac, - ) + @_add_doc(_model) + def _smart_init_latent_parameters(self): + if not hasattr(self, "_latent_mean"): + self._latent_mean = ( + _init_latent_mean( + self._endog, + self._exog, + self._offsets, + self._coef, + self._components, + ) + .to(DEVICE) + .detach() + ) + if not hasattr(self, "_latent_sqrt_var"): + self._latent_sqrt_var = ( + 1 / 2 * torch.ones((self.n_samples, self._rank)).to(DEVICE) + ) @property + @_add_doc(_model) def _list_of_parameters_needing_gradient(self): - return [self._latent_mean, self._latent_sqrt_var, self._coef_inflation] + if self._coef is None: + return [self._components, self._latent_mean, self._latent_sqrt_var] + return [self._components, self._coef, self._latent_mean, self._latent_sqrt_var] - def _update_closed_forms(self): - self._coef = _closed_formula_coef(self._exog, self._latent_mean) - self._covariance = _closed_formula_covariance( - self._exog, - self._latent_mean, - self._latent_sqrt_var, - self._coef, - self.n_samples, - ) - self._pi = _closed_formula_pi( - self._offsets, - self._latent_mean, - self._latent_sqrt_var, - self._dirac, - self._exog, - self._coef_inflation, - ) + @property + @_add_doc(_model) + def model_parameters(self) -> Dict[str, torch.Tensor]: + return {"coef": self.coef, "components": self.components} @property - def number_of_parameters(self): - return self.dim * (2 * self.nb_cov + (self.dim + 1) / 2) + @_add_doc(_model) + def latent_parameters(self): + return { + "latent_sqrt_var": self.latent_sqrt_var, + "latent_mean": self.latent_mean, + } + + +class ZIPln(_model): + _NAME = "ZIPln" + + _latent_prob: torch.Tensor + _coef_inflation: torch.Tensor + _dirac: torch.Tensor + + def __init__( + self, + endog: Optional[Union[torch.Tensor, np.ndarray, pd.DataFrame]], + *, + exog: Optional[Union[torch.Tensor, np.ndarray, pd.DataFrame]] = None, + offsets: Optional[Union[torch.Tensor, np.ndarray, pd.DataFrame]] = None, + offsets_formula: str = "logsum", + dict_initialization: Optional[Dict[str, torch.Tensor]] = None, + take_log_offsets: bool = False, + add_const: bool = True, + use_closed_form_prob: bool = False, + ): + """ + Initializes the ZIPln class. + + Parameters + ---------- + endog : Union[torch.Tensor, np.ndarray, pd.DataFrame] + The count data. + exog : Union[torch.Tensor, np.ndarray, pd.DataFrame], optional(keyword-only) + The covariate data. Defaults to None. + offsets : Union[torch.Tensor, np.ndarray, pd.DataFrame], optional(keyword-only) + The offsets data. Defaults to None. + offsets_formula : str, optional(keyword-only) + The formula for offsets. Defaults to "logsum". Overriden if + offsets is not None. + dict_initialization : dict, optional(keyword-only) + The initialization dictionary. Defaults to None. + take_log_offsets : bool, optional(keyword-only) + Whether to take the log of offsets. Defaults to False. + add_const : bool, optional(keyword-only) + Whether to add a column of one in the exog. Defaults to True. + If exog is None, add_const is set to True anyway and a warnings + is launched. + use_closed_form_prob : bool, optional + Whether or not use the closed formula for the latent probability. + Default is False. 
+ Raises + ------ + ValueError + If the batch_size is greater than the number of samples, or not int. + Returns + ------- + A ZIPln object + See also + -------- + :func:`pyPLNmodels.ZIPln.from_formula` + Examples + -------- + >>> from pyPLNmodels import ZIPln, get_real_count_data + >>> endog= get_real_count_data() + >>> zi = ZIPln(endog, add_const = True) + >>> zi.fit() + >>> print(zi) + """ + self._use_closed_form_prob = use_closed_form_prob + if exog is None and add_const is False: + msg = "No covariates has been given. An intercept is added since " + msg += "a ZIPln must have at least an intercept." + warnings.warn(msg) + add_const = True + super().__init__( + endog=endog, + exog=exog, + offsets=offsets, + offsets_formula=offsets_formula, + dict_initialization=dict_initialization, + take_log_offsets=take_log_offsets, + add_const=add_const, + ) + + def _extract_batch(self, batch): + super()._extract_batch(batch) + self._dirac_b = batch[5] + if self._use_closed_form_prob is False: + self._latent_prob_b = batch[6] + + def _return_batch(self, indices, beginning, end): + pln_batch = super()._return_batch(indices, beginning, end) + to_take = torch.tensor(indices[beginning:end]).to(DEVICE) + batch = pln_batch + (torch.index_select(self._dirac, 0, to_take),) + if self._use_closed_form_prob is False: + to_return = torch.index_select(self._latent_prob, 0, to_take) + return batch + (torch.index_select(self._latent_prob, 0, to_take),) + return batch + + @classmethod + def from_formula( + cls, + formula: str, + data: Dict[str, Union[torch.Tensor, np.ndarray, pd.DataFrame]], + *, + offsets_formula: str = "logsum", + dict_initialization: Optional[Dict[str, torch.Tensor]] = None, + take_log_offsets: bool = False, + use_closed_form_prob: bool = False, + ): + """ + Create a ZIPln instance from a formula and data. + + Parameters + ---------- + formula : str + The formula. + data : dict + The data dictionary. Each value can be either a torch.Tensor, + a np.ndarray or pd.DataFrame + offsets_formula : str, optional(keyword-only) + The formula for offsets. Defaults to "logsum". + dict_initialization : dict, optional(keyword-only) + The initialization dictionary. Defaults to None. + take_log_offsets : bool, optional(keyword-only) + Whether to take the log of offsets. Defaults to False. + use_closed_form_prob : bool, optional + Whether or not use the closed formula for the latent probability. + Default is False. 
+ Returns + ------- + A ZIPln object + See also + -------- + :class:`pyPLNmodels.ZIPln` + :func:`pyPLNmodels.ZIPln.__init__` + Examples + -------- + >>> from pyPLNmodels import ZIPln, get_real_count_data + >>> endog = get_real_count_data() + >>> data = {"endog": endog} + >>> zi = ZIPln.from_formula("endog ~ 1", data = data) + """ + endog, exog, offsets = _extract_data_from_formula(formula, data) + return cls( + endog, + exog=exog, + offsets=offsets, + offsets_formula=offsets_formula, + dict_initialization=dict_initialization, + take_log_offsets=take_log_offsets, + add_const=False, + use_closed_form_prob=use_closed_form_prob, + ) + + @_add_doc( + _model, + example=""" + >>> from pyPLNmodels import ZIPln, get_real_count_data + >>> endog = get_real_count_data() + >>> zi = ZIPln(endog,add_const = True) + >>> zi.fit() + >>> print(zi) + """, + ) + def fit( + self, + nb_max_iteration: int = 50000, + *, + lr: float = 0.01, + tol: float = 1e-3, + do_smart_init: bool = True, + verbose: bool = False, + batch_size: int = None, + ): + super().fit( + nb_max_iteration, + lr=lr, + tol=tol, + do_smart_init=do_smart_init, + verbose=verbose, + batch_size=batch_size, + ) + + @_add_doc( + _model, + example=""" + >>> import matplotlib.pyplot as plt + >>> from pyPLNmodels import ZIPln, get_real_count_data + >>> endog, labels = get_real_count_data(return_labels = True) + >>> zi = ZIPln(endog,add_const = True) + >>> zi.fit() + >>> zi.plot_expected_vs_true() + >>> plt.show() + >>> zi.plot_expected_vs_true(colors = labels) + >>> plt.show() + """, + ) + def plot_expected_vs_true(self, ax=None, colors=None): + super().plot_expected_vs_true(ax=ax, colors=colors) + + @property + def _description(self): + return "with full covariance model and zero-inflation." + + def _random_init_model_parameters(self): + self._coef_inflation = torch.randn(self.nb_cov, self.dim).to(DEVICE) + self._coef = torch.randn(self.nb_cov, self.dim).to(DEVICE) + self._components = torch.randn(self.dim, self.dim).to(DEVICE) + + # should change the good initialization for _coef_inflation + def _smart_init_model_parameters(self): + # init of _coef. + super()._smart_init_coef() + if not hasattr(self, "_covariance"): + self._components = _init_components(self._endog, self.dim) + + if not hasattr(self, "_coef_inflation"): + self._coef_inflation = torch.randn(self.nb_cov, self.dim).to(DEVICE) + # for j in range(self.exog.shape[1]): + # Y_j = self._endog[:,j].numpy() + # offsets_j = self.offsets[:,j].numpy() + # exog = self.exog[:,j].unsqueeze(1).numpy() + # undzi = ZeroInflatedPoisson(endog=Y_j,exog = exog, exog_infl = exog, inflation='logit', offset = offsets_j) + # zip_training_results = undzi.fit() + # self._coef_inflation[:,j] = zip_training_results.params[1] + + def _random_init_latent_parameters(self): + self._latent_mean = torch.randn(self.n_samples, self.dim).to(DEVICE) + self._latent_sqrt_var = torch.randn(self.n_samples, self.dim).to(DEVICE) + self._latent_prob = ( + ( + torch.empty(self.n_samples, self.dim).uniform_(0, 1).to(DEVICE) + * self._dirac + ) + .double() + .to(DEVICE) + ) + + def _smart_init_latent_parameters(self): + self._random_init_latent_parameters() + + @property + def _covariance(self): + return self._components @ (self._components.T) + + def _get_max_components(self): + """ + Method for getting the maximum number of components. + + Returns + ------- + int + The maximum number of components. + """ + return self.dim + + @property + def components(self) -> torch.Tensor: + """ + Property representing the components. 
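Since the ZIPln covariance is parameterized through its square root, the two
accessors should agree after fitting, up to floating point error (a small
sanity sketch, not part of the API contract):

>>> import torch
>>> from pyPLNmodels import ZIPln, get_real_count_data
>>> zi = ZIPln(get_real_count_data())
>>> zi.fit()
>>> print(torch.allclose(zi.covariance, zi.components @ zi.components.T))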
+ + Returns + ------- + torch.Tensor + The components. + """ + return self._cpu_attribute_or_none("_components") + + @property + def latent_variables(self) -> tuple([torch.Tensor, torch.Tensor]): + """ + Property representing the latent variables. Two latent + variables are available if exog is not None + + Returns + ------- + tuple(torch.Tensor, torch.Tensor) + The latent variables of a classic Pln model (size (n_samples, dim)) + and zero inflated latent variables of size (n_samples, dim). + Examples + -------- + >>> from pyPLNmodels import ZIPln, get_real_count_data + >>> endog, labels = get_real_count_data(return_labels = True) + >>> zi = ZIPln(endog,add_const = True) + >>> zi.fit() + >>> latent_mean, latent_inflated = zi.latent_variables + >>> print(latent_mean.shape) + >>> print(latent_inflated.shape) + """ + return self.latent_mean, self.latent_prob + + def transform(self, return_latent_prob=False): + """ + Method for transforming the endog. Can be seen as a normalization of the endog. + + Parameters + ---------- + return_latent_prob: bool, optional + Wheter to return or not the latent_probability of zero inflation. + Returns + ------- + The latent mean if `return_latent_prob` is False and (latent_mean, latent_prob) else. + """ + if return_latent_prob is True: + return self.latent_variables + return self.latent_mean + + def _endog_predictions(self): + return torch.exp( + self._offsets + self._latent_mean + 1 / 2 * self._latent_sqrt_var**2 + ) * (1 - self._latent_prob) + + @property + def coef_inflation(self): + """ + Property representing the coefficients of the inflation. + + Returns + ------- + torch.Tensor or None + The coefficients or None. + """ + return self._cpu_attribute_or_none("_coef_inflation") + + @coef_inflation.setter + @_array2tensor + def coef_inflation( + self, coef_inflation: Union[torch.Tensor, np.ndarray, pd.DataFrame] + ): + """ + Setter for the coef_inflation property. + + Parameters + ---------- + coef : Union[torch.Tensor, np.ndarray, pd.DataFrame] + The coefficients. + + Raises + ------ + ValueError + If the shape of the coef is incorrect. + """ + if coef_inflation.shape != (self.nb_cov, self.dim): + raise ValueError( + f"Wrong shape for the coef. Expected {(self.nb_cov, self.dim)}, got {coef_inflation.shape}" + ) + self._coef_inflation = coef_inflation + + @_model.latent_sqrt_var.setter + @_array2tensor + def latent_sqrt_var( + self, latent_sqrt_var: Union[torch.Tensor, np.ndarray, pd.DataFrame] + ): + """ + Setter for the latent variance property. + + Parameters + ---------- + latent_sqrt_var : Union[torch.Tensor, np.ndarray, pd.DataFrame] + The latent square root of the variance. + + Raises + ------ + ValueError + If the shape of the latent variance is incorrect. + """ + if latent_sqrt_var.shape != (self.n_samples, self.dim): + raise ValueError( + f"Wrong shape. Expected {self.n_samples, self.dim}, got {latent_sqrt_var.shape}" + ) + self._latent_sqrt_var = latent_sqrt_var + + def _project_parameters(self): + self._project_latent_prob() + + def _project_latent_prob(self): + """ + Project the latent probability since it must be between 0 and 1. + """ + if self._use_closed_form_prob is False: + with torch.no_grad(): + self._latent_prob = torch.maximum( + self._latent_prob, torch.tensor([0]).to(DEVICE) + ) + self._latent_prob = torch.minimum( + self._latent_prob, torch.tensor([1]).to(DEVICE) + ) + self._latent_prob *= self._dirac + + @property + def covariance(self) -> torch.Tensor: + """ + Property representing the covariance of the latent variables. 
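A short sketch of the transform method defined above (variable names are
illustrative):

>>> from pyPLNmodels import ZIPln, get_real_count_data
>>> zi = ZIPln(get_real_count_data())
>>> zi.fit()
>>> latent_mean, latent_prob = zi.transform(return_latent_prob=True)
>>> normalized = zi.transform()  # latent means only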
+ + Returns + ------- + Optional[torch.Tensor] + The covariance tensor or None if components are not present. + """ + return self._cpu_attribute_or_none("_covariance") + + @components.setter + @_array2tensor + def components(self, components: torch.Tensor): + """ + Setter for the components. + + Parameters + ---------- + components : torch.Tensor + The components to set. + + Raises + ------ + ValueError + If the components have an invalid shape. + """ + if components.shape != (self.dim, self.dim): + raise ValueError( + f"Wrong shape. Expected {self.dim, self.dim}, got {components.shape}" + ) + self._components = components + + @property + def latent_prob(self): + return self._cpu_attribute_or_none("_latent_prob") + + @latent_prob.setter + @_array2tensor + def latent_prob(self, latent_prob: Union[torch.Tensor, np.ndarray, pd.DataFrame]): + if self._use_closed_form_prob is True: + raise ValueError( + "Can not set the latent prob when the closed form is used." + ) + if latent_prob.shape != (self.n_samples, self.dim): + raise ValueError( + f"Wrong shape. Expected {self.n_samples, self.dim}, got {latent_prob.shape}" + ) + if torch.max(latent_prob) > 1 or torch.min(latent_prob) < 0: + raise ValueError(f"Wrong value. All values should be between 0 and 1.") + if torch.norm(latent_prob * (self._endog == 0) - latent_prob) > 0.00000001: + raise ValueError( + "You can not assign non zeros inflation probabilities to non zero counts." + ) + self._latent_prob = latent_prob + + @property + def closed_formula_latent_prob(self): + """ + The closed form for the latent probability. + """ + return closed_formula_latent_prob( + self._exog, self._coef, self._coef_inflation, self._covariance, self._dirac + ) + + def compute_elbo(self): + if self._use_closed_form_prob is True: + latent_prob = self.closed_formula_latent_prob + else: + latent_prob = self._latent_prob + return elbo_zi_pln( + self._endog, + self._exog, + self._offsets, + self._latent_mean, + self._latent_sqrt_var, + latent_prob, + self._components, + self._coef, + self._coef_inflation, + self._dirac, + ) + + def _compute_elbo_b(self): + if self._use_closed_form_prob is True: + latent_prob_b = _closed_formula_latent_prob( + self._exog_b, + self._coef, + self._coef_inflation, + self._covariance, + self._dirac_b, + ) + else: + latent_prob_b = self._latent_prob_b + return elbo_zi_pln( + self._endog_b, + self._exog_b, + self._offsets_b, + self._latent_mean_b, + self._latent_sqrt_var_b, + latent_prob_b, + self._components, + self._coef, + self._coef_inflation, + self._dirac_b, + ) + + @property + def number_of_parameters(self): + return self.dim * (2 * self.nb_cov + (self.dim + 1) / 2) + + @property + @_add_doc(_model) + def _list_of_parameters_needing_gradient(self): + list_parameters = [ + self._latent_mean, + self._latent_sqrt_var, + self._components, + ] + if self._use_closed_form_prob is False: + list_parameters.append(self._latent_prob) + if self._exog is not None: + list_parameters.append(self._coef) + list_parameters.append(self._coef_inflation) + return list_parameters + + @property + @_add_doc(_model) + def model_parameters(self) -> Dict[str, torch.Tensor]: + return { + "coef": self.coef, + "components": self.components, + "coef_inflation": self.coef_inflation, + } + + def predict_prob_inflation( + self, exog: Union[torch.Tensor, np.ndarray, pd.DataFrame] + ): + """ + Method for estimating the probability of a zero coming from the zero inflated component. 
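A usage sketch, reusing the model's own design matrix purely for illustration
(any matrix with `nb_cov` columns would do):

>>> from pyPLNmodels import ZIPln, get_real_count_data
>>> zi = ZIPln(get_real_count_data())  # intercept-only design, nb_cov = 1
>>> zi.fit()
>>> proba = zi.predict_prob_inflation(zi.exog)
>>> print(proba.shape)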
+ + Parameters + ---------- + exog : Union[torch.Tensor, np.ndarray, pd.DataFrame] + The exog. + + Returns + ------- + torch.Tensor + The predicted values. + + Raises + ------ + RuntimeError + If the shape of the exog is incorrect. + + Notes + ----- + - The mean sigmoid(exog @ coef_inflation) is returned. + - `exog` should have the shape `(_, nb_cov)`, where `nb_cov` is the number of exog variables. + """ + if exog is not None and self.nb_cov == 0: + raise AttributeError("No exog in the model, can't predict") + if exog.shape[-1] != self.nb_cov: + error_string = f"X has wrong shape ({exog.shape}). Should" + error_string += f" be (_, {self.nb_cov})." + raise RuntimeError(error_string) + return torch.sigmoid(exog @ self.coef_inflation) + + @property + @_add_doc(_model) + def latent_parameters(self): + latent_param = { + "latent_sqrt_var": self.latent_sqrt_var, + "latent_mean": self.latent_mean, + } + if self._use_closed_form_prob is False: + latent_param["latent_prob"] = self.latent_prob + return latent_param + + def grad_M(self): + if self.use_closed_form_prob is True: + latent_prob = self.closed_formula_latent_prob + else: + latent_prob = self._latent_prob + un_moins_prob = 1 - latent_prob + first = un_moins_prob * ( + self._endog + - torch.exp( + self._offsets + self._latent_mean + self.latent_sqrt_var**2 / 2 + ) + ) + MmoinsXB = self._latent_mean - self._exog @ self._coef + A = (un_moins_prob * MmoinsXB) @ torch.inverse(self._covariance) + diag_omega = torch.diag(torch.inverse(self._covariance)) + full_diag_omega = diag_omega.expand(self.exog.shape[0], -1) + second = -un_moins_prob * A + added = -full_diag_omega * latent_prob * un_moins_prob * (MmoinsXB) + return first + second + added + + def grad_S(self): + if self.use_closed_form_prob is True: + latent_prob = self.closed_formula_latent_prob + else: + latent_prob = self._latent_prob + Omega = torch.inverse(self.covariance) + un_moins_prob = 1 - latent_prob + first = un_moins_prob * torch.exp( + self._offsets + self._latent_mean + self._latent_sqrt_var**2 / 2 + ) + first = -torch.multiply(first, self._latent_sqrt_var) + sec = un_moins_prob * 1 / self._latent_sqrt_var + K = un_moins_prob * ( + torch.multiply( + torch.full((self.n_samples, 1), 1.0), torch.diag(Omega).unsqueeze(0) + ) + ) + third = -self._latent_sqrt_var * K + return first + sec + third + + def grad_theta(self): + if self.use_closed_form_prob is True: + latent_prob = self.closed_formula_latent_prob + else: + latent_prob = self._latent_prob + + un_moins_prob = 1 - latent_prob + MmoinsXB = self._latent_mean - self._exog @ self._coef + A = (un_moins_prob * MmoinsXB) @ torch.inverse(self._covariance) + diag_omega = torch.diag(torch.inverse(self._covariance)) + full_diag_omega = diag_omega.expand(self.exog.shape[0], -1) + added = latent_prob * (MmoinsXB) * full_diag_omega + A += added + second = -un_moins_prob * A + grad_no_closed_form = -self._exog.T @ second + if self.use_closed_form_prob is False: + return grad_no_closed_form + else: + XB_zero = self._exog @ self._coef_inflation + diag = torch.diag(self._covariance) + full_diag = diag.expand(self._exog.shape[0], -1) + XB = self._exog @ self._coef + derivative = d_h_x2(XB_zero, XB, full_diag, self._dirac) + grad_closed_form = self.gradients_closed_form_thetas(derivative) + return grad_closed_form + grad_no_closed_form + + def gradients_closed_form_thetas(self, derivative): + Omega = torch.inverse(self._covariance) + MmoinsXB = self._latent_mean - self._exog @ self._coef + s_rond_s = self._latent_sqrt_var**2 + latent_prob = 
self.closed_formula_latent_prob + A = torch.exp(self._offsets + self._latent_mean + s_rond_s / 2) + poiss_term = ( + self._endog * (self._offsets + self._latent_mean) + - A + - _log_stirling(self._endog) + ) + a = -self._exog.T @ (derivative * poiss_term) + b = self._exog.T @ ( + derivative * MmoinsXB * (((1 - latent_prob) * MmoinsXB) @ Omega) + ) + c = self._exog.T @ (derivative * (self._exog @ self._coef_inflation)) + first_d = derivative * torch.log(torch.abs(self._latent_sqrt_var)) + second_d = ( + 1 / 2 * derivative @ (torch.diag(torch.log(torch.diag(self._covariance)))) + ) + d = -self._exog.T @ (first_d - second_d) + e = -self._exog.T @ ( + derivative * (_trunc_log(latent_prob) - _trunc_log(1 - latent_prob)) + ) + first_f = ( + +1 + / 2 + * self._exog.T + @ (derivative * (s_rond_s @ torch.diag(torch.diag(Omega)))) + ) + second_f = ( + -1 + / 2 + * self._exog.T + @ derivative + @ torch.diag(torch.diag(Omega) * torch.diag(self._covariance)) + ) + full_diag_omega = torch.diag(Omega).expand(self.exog.shape[0], -1) + common = (MmoinsXB) ** 2 * (full_diag_omega) + new_f = -1 / 2 * self._exog.T @ (derivative * common * (1 - 2 * latent_prob)) + f = first_f + second_f + new_f + return a + b + c + d + e + f + + def grad_theta_0(self): + if self.use_closed_form_prob is True: + latent_prob = self.closed_formula_latent_prob + else: + latent_prob = self._latent_prob + grad_no_closed_form = self._exog.T @ latent_prob - self._exog.T @ ( + torch.exp(self._exog @ self._coef_inflation) + / (1 + torch.exp(self._exog @ self._coef_inflation)) + ) + if self.use_closed_form_prob is False: + return grad_no_closed_form + else: + grad_closed_form = self.gradients_closed_form_thetas( + latent_prob * (1 - latent_prob) + ) + return grad_closed_form + grad_no_closed_form + + def grad_C(self): + if self.use_closed_form_prob is True: + latent_prob = self.closed_formula_latent_prob + else: + latent_prob = self._latent_prob + omega = torch.inverse(self._covariance) + if self._coef is not None: + m_minus_xb = self._latent_mean - torch.mm(self._exog, self._coef) + else: + m_minus_xb = self._latent_mean + m_moins_xb_outer = torch.mm(m_minus_xb.T, m_minus_xb) + + un_moins_rho = 1 - latent_prob + + un_moins_rho_m_moins_xb = un_moins_rho * m_minus_xb + un_moins_rho_m_moins_xb_outer = ( + un_moins_rho_m_moins_xb.T @ un_moins_rho_m_moins_xb + ) + deter = ( + -self.n_samples + * torch.inverse(self._components @ (self._components.T)) + @ self._components + ) + sec_part_b_grad = ( + omega @ (un_moins_rho_m_moins_xb_outer) @ omega @ self._components + ) + b_grad = deter + sec_part_b_grad + + diag = torch.diag(self.covariance) + rho_t_unn = torch.sum(latent_prob, axis=0) + omega_unp = torch.sum(omega, axis=0) + K = torch.sum(un_moins_rho * self._latent_sqrt_var**2, axis=0) + diag * ( + rho_t_unn + ) + added = torch.sum(latent_prob * un_moins_rho * (m_minus_xb**2), axis=0) + K += added + first_part_grad = omega @ torch.diag_embed(K) @ omega @ self._components + x = torch.diag(omega) * rho_t_unn + second_part_grad = -torch.diag_embed(x) @ self._components + y = rho_t_unn + first = torch.multiply(y, 1 / torch.diag(self.covariance)).unsqueeze(1) + second = torch.full((1, self.dim), 1.0) + Diag = (first * second) * torch.eye(self.dim) + last_grad = Diag @ self._components + grad_no_closed_form = b_grad + first_part_grad + second_part_grad + last_grad + if self.use_closed_form_prob is False: + return grad_no_closed_form + else: + s_rond_s = self._latent_sqrt_var**2 + XB_zero = self._exog @ self._coef_inflation + XB = self._exog @ 
self._coef + A = torch.exp(self._offsets + self._latent_mean + s_rond_s / 2) + poiss_term = ( + self._endog * (self._offsets + self._latent_mean) + - A + - _log_stirling(self._endog) + ) + full_diag_sigma = diag.expand(self._exog.shape[0], -1) + full_diag_omega = torch.diag(omega).expand(self._exog.shape[0], -1) + H3 = d_h_x3(XB_zero, XB, full_diag_sigma, self._dirac) + poiss_term_H = poiss_term * H3 + a = ( + -2 + * ( + ((poiss_term_H.T @ torch.ones(self.n_samples, self.dim))) + * (torch.eye(self.dim)) + ) + @ self._components + ) + B_Omega = ((1 - latent_prob) * m_minus_xb) @ omega + K = H3 * B_Omega * m_minus_xb + b = ( + 2 + * ( + ( + (m_minus_xb * B_Omega * H3).T + @ torch.ones(self.n_samples, self.dim) + ) + * torch.eye(self.dim) + ) + @ self._components + ) + c = ( + 2 + * ( + ((XB_zero * H3).T @ torch.ones(self.n_samples, self.dim)) + * torch.eye(self.dim) + ) + @ self._components + ) + d = ( + -2 + * ( + ( + (torch.log(torch.abs(self._latent_sqrt_var)) * H3).T + @ torch.ones(self.n_samples, self.dim) + ) + * torch.eye(self.dim) + ) + @ self._components + ) + log_full_diag_sigma = torch.log(diag).expand(self._exog.shape[0], -1) + d += ( + ((log_full_diag_sigma * H3).T @ torch.ones(self.n_samples, self.dim)) + * torch.eye(self.dim) + ) @ self._components + e = ( + -2 + * ( + ( + ((_trunc_log(latent_prob) - _trunc_log(1 - latent_prob)) * H3).T + @ torch.ones(self.n_samples, self.dim) + ) + * torch.eye(self.dim) + ) + @ self._components + ) + f = ( + -( + ( + (full_diag_omega * (full_diag_sigma - s_rond_s) * H3).T + @ torch.ones(self.n_samples, self.dim) + ) + * torch.eye(self.dim) + ) + @ self._components + ) + f -= ( + ( + ((1 - 2 * latent_prob) * m_minus_xb**2 * full_diag_omega * H3).T + @ torch.ones(self.n_samples, self.dim) + ) + * torch.eye(self.dim) + ) @ self._components + grad_closed_form = a + b + c + d + e + f + return grad_closed_form + grad_no_closed_form + + def grad_rho(self): + if self.use_closed_form_prob is True: + latent_prob = self.closed_formula_latent_prob + else: + latent_prob = self._latent_prob + omega = torch.inverse(self._covariance) + s_rond_s = self._latent_sqrt_var * self._latent_sqrt_var + A = torch.exp(self._offsets + self._latent_mean + s_rond_s / 2) + first = ( + -self._endog * (self._offsets + self._latent_mean) + + A + + _log_stirling(self._endog) + ) + un_moins_prob = 1 - latent_prob + MmoinsXB = self._latent_mean - self._exog @ self._coef + A = (un_moins_prob * MmoinsXB) @ torch.inverse(self._covariance) + second = MmoinsXB * A + third = self._exog @ self._coef_inflation + fourth_first = -torch.log(torch.abs(self._latent_sqrt_var)) + fourth_second = ( + 1 + / 2 + * torch.multiply( + torch.full((self.n_samples, 1), 1.0), + torch.log(torch.diag(self.covariance)).unsqueeze(0), + ) + ) + fourth = fourth_first + fourth_second + fifth = _trunc_log(un_moins_prob) - _trunc_log(latent_prob) + sixth_first = ( + 1 + / 2 + * torch.multiply( + torch.full((self.n_samples, 1), 1.0), torch.diag(omega).unsqueeze(0) + ) + * s_rond_s + ) + sixth_second = ( + -1 + / 2 + * torch.multiply( + torch.full((self.n_samples, 1), 1.0), + (torch.diag(omega) * torch.diag(self._covariance)).unsqueeze(0), + ) + ) + sixth = sixth_first + sixth_second + full_diag_omega = torch.diag(omega).expand(self.exog.shape[0], -1) + seventh = -1 / 2 * (1 - 2 * latent_prob) * (MmoinsXB) ** 2 * (full_diag_omega) + return first + second + third + fourth + fifth + sixth + seventh diff --git a/pyPLNmodels/new_model.py b/pyPLNmodels/new_model.py new file mode 100644 index 
0000000000000000000000000000000000000000..2d4acd4598252de39d8f22ae17c52fa2f5adcc04 --- /dev/null +++ b/pyPLNmodels/new_model.py @@ -0,0 +1,9 @@ +from pyPLNmodels import ZIPln, get_real_count_data + + +endog = get_real_count_data() +zi = ZIPln(endog, add_const = True) +zi.fit(nb_max_iteration = 10) +zi.show() + + diff --git a/tests/conftest.py b/tests/conftest.py index 3a072f2016de6dd0e6e4b0240eb18fc04708a5b3..d89a919ac083c0c1bd4ae774fdcee79629015439 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,22 +1,16 @@ import sys -import glob -from functools import singledispatch import pytest import torch from pytest_lazyfixture import lazy_fixture as lf import pandas as pd from pyPLNmodels import load_model, load_plnpcacollection -from pyPLNmodels.models import Pln, PlnPCA, PlnPCAcollection +from pyPLNmodels.models import Pln, PlnPCA, PlnPCAcollection, ZIPln +from pyPLNmodels import get_simulated_count_data sys.path.append("../") -pytest_plugins = [ - fixture_file.replace("/", ".").replace(".py", "") - for fixture_file in glob.glob("src/**/tests/fixtures/[!__]*.py", recursive=True) -] - from tests.import_data import ( data_sim_0cov, @@ -43,6 +37,11 @@ def add_fixture_to_dict(my_dict, string_fixture): return my_dict +# zi = ZIPln(endog_sim_2cov, exog = exog_sim_2cov) +# zi.fit() +# print(zi) + + def add_list_of_fixture_to_dict( my_dict, name_of_list_of_fixtures, list_of_string_fixtures ): @@ -83,19 +82,26 @@ def convenientpln(*args, **kwargs): return Pln(*args, **kwargs) +def convenientzi(*args, **kwargs): + if isinstance(args[0], str): + return ZIPln.from_formula(*args, **kwargs) + return ZIPln(*args, **kwargs) + + def generate_new_model(model, *args, **kwargs): name_dir = model._directory_name - print("directory name", name_dir) name = model._NAME - if name in ("Pln", "PlnPCA"): - path = model._path_to_directory + name_dir + if name in ("Pln", "PlnPCA", "ZIPln"): + path = model._directory_name init = load_model(path) if name == "Pln": new = convenientpln(*args, **kwargs, dict_initialization=init) if name == "PlnPCA": new = convenient_PlnPCA(*args, **kwargs, dict_initialization=init) + if name == "ZIPln": + new = convenientzi(*args, **kwargs, dict_initialization=init) if name == "PlnPCAcollection": - init = load_plnpcacollection(name_dir) + init = load_plnpcacollection(model._directory_name) new = convenient_PlnPCAcollection(*args, **kwargs, dict_initialization=init) return new @@ -111,67 +117,67 @@ def cache(func): return new_func -params = [convenientpln, convenient_PlnPCA, convenient_PlnPCAcollection] +params = [convenientpln, convenient_PlnPCA, convenient_PlnPCAcollection, convenientzi] dict_fixtures = {} @pytest.fixture(params=params) -def simulated_pln_0cov_array(request): +def simulated_model_0cov_array(request): cls = request.param - pln = cls( + model = cls( endog_sim_0cov, exog=exog_sim_0cov, offsets=offsets_sim_0cov, add_const=False, ) - return pln + return model @pytest.fixture(params=params) @cache -def simulated_fitted_pln_0cov_array(request): +def simulated_fitted_model_0cov_array(request): cls = request.param - pln = cls( + model = cls( endog_sim_0cov, exog=exog_sim_0cov, offsets=offsets_sim_0cov, add_const=False, ) - pln.fit() - return pln + model.fit() + return model @pytest.fixture(params=params) -def simulated_pln_0cov_formula(request): +def simulated_model_0cov_formula(request): cls = request.param - pln = cls("endog ~ 0", data_sim_0cov) - return pln + model = cls("endog ~ 0", data_sim_0cov) + return model @pytest.fixture(params=params) @cache -def 
simulated_fitted_pln_0cov_formula(request): +def simulated_fitted_model_0cov_formula(request): cls = request.param - pln = cls("endog ~ 0", data_sim_0cov) - pln.fit() - return pln + model = cls("endog ~ 0", data_sim_0cov) + model.fit() + return model @pytest.fixture -def simulated_loaded_pln_0cov_formula(simulated_fitted_pln_0cov_formula): - simulated_fitted_pln_0cov_formula.save() +def simulated_loaded_model_0cov_formula(simulated_fitted_model_0cov_formula): + simulated_fitted_model_0cov_formula.save() return generate_new_model( - simulated_fitted_pln_0cov_formula, + simulated_fitted_model_0cov_formula, "endog ~ 0", data_sim_0cov, ) @pytest.fixture -def simulated_loaded_pln_0cov_array(simulated_fitted_pln_0cov_array): - simulated_fitted_pln_0cov_array.save() +def simulated_loaded_model_0cov_array(simulated_fitted_model_0cov_array): + simulated_fitted_model_0cov_array.save() return generate_new_model( - simulated_fitted_pln_0cov_array, + simulated_fitted_model_0cov_array, endog_sim_0cov, exog=exog_sim_0cov, offsets=offsets_sim_0cov, @@ -179,87 +185,95 @@ def simulated_loaded_pln_0cov_array(simulated_fitted_pln_0cov_array): ) -sim_pln_0cov_instance = [ - "simulated_pln_0cov_array", - "simulated_pln_0cov_formula", +sim_model_0cov_instance = [ + "simulated_model_0cov_array", + "simulated_model_0cov_formula", ] -instances = sim_pln_0cov_instance + instances +instances = sim_model_0cov_instance + instances dict_fixtures = add_list_of_fixture_to_dict( - dict_fixtures, "sim_pln_0cov_instance", sim_pln_0cov_instance + dict_fixtures, "sim_model_0cov_instance", sim_model_0cov_instance ) -sim_pln_0cov_fitted = [ - "simulated_fitted_pln_0cov_array", - "simulated_fitted_pln_0cov_formula", +sim_model_0cov_fitted = [ + "simulated_fitted_model_0cov_array", + "simulated_fitted_model_0cov_formula", ] dict_fixtures = add_list_of_fixture_to_dict( - dict_fixtures, "sim_pln_0cov_fitted", sim_pln_0cov_fitted + dict_fixtures, "sim_model_0cov_fitted", sim_model_0cov_fitted ) -sim_pln_0cov_loaded = [ - "simulated_loaded_pln_0cov_array", - "simulated_loaded_pln_0cov_formula", + +sim_model_0cov_loaded = [ + "simulated_loaded_model_0cov_array", + "simulated_loaded_model_0cov_formula", ] dict_fixtures = add_list_of_fixture_to_dict( - dict_fixtures, "sim_pln_0cov_loaded", sim_pln_0cov_loaded + dict_fixtures, "sim_model_0cov_loaded", sim_model_0cov_loaded +) + +sim_model_0cov = sim_model_0cov_instance + sim_model_0cov_fitted + sim_model_0cov_loaded +dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "sim_model_0cov", sim_model_0cov ) -sim_pln_0cov = sim_pln_0cov_instance + sim_pln_0cov_fitted + sim_pln_0cov_loaded -dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "sim_pln_0cov", sim_pln_0cov) +sim_model_0cov_fitted_and_loaded = sim_model_0cov_fitted + sim_model_0cov_loaded +dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "sim_model_0cov_fitted_and_loaded", sim_model_0cov_fitted_and_loaded +) @pytest.fixture(params=params) @cache -def simulated_pln_2cov_array(request): +def simulated_model_2cov_array(request): cls = request.param - pln_full = cls( + model = cls( endog_sim_2cov, exog=exog_sim_2cov, offsets=offsets_sim_2cov, add_const=False, ) - return pln_full + return model @pytest.fixture -def simulated_fitted_pln_2cov_array(simulated_pln_2cov_array): - simulated_pln_2cov_array.fit() - return simulated_pln_2cov_array +def simulated_fitted_model_2cov_array(simulated_model_2cov_array): + simulated_model_2cov_array.fit() + return simulated_model_2cov_array @pytest.fixture(params=params) 
@cache -def simulated_pln_2cov_formula(request): +def simulated_model_2cov_formula(request): cls = request.param - pln_full = cls("endog ~ 0 + exog", data_sim_2cov) - return pln_full + model = cls("endog ~ 0 + exog", data_sim_2cov) + return model @pytest.fixture -def simulated_fitted_pln_2cov_formula(simulated_pln_2cov_formula): - simulated_pln_2cov_formula.fit() - return simulated_pln_2cov_formula +def simulated_fitted_model_2cov_formula(simulated_model_2cov_formula): + simulated_model_2cov_formula.fit() + return simulated_model_2cov_formula @pytest.fixture -def simulated_loaded_pln_2cov_formula(simulated_fitted_pln_2cov_formula): - simulated_fitted_pln_2cov_formula.save() +def simulated_loaded_model_2cov_formula(simulated_fitted_model_2cov_formula): + simulated_fitted_model_2cov_formula.save() return generate_new_model( - simulated_fitted_pln_2cov_formula, + simulated_fitted_model_2cov_formula, "endog ~0 + exog", data_sim_2cov, ) @pytest.fixture -def simulated_loaded_pln_2cov_array(simulated_fitted_pln_2cov_array): - simulated_fitted_pln_2cov_array.save() +def simulated_loaded_model_2cov_array(simulated_fitted_model_2cov_array): + simulated_fitted_model_2cov_array.save() return generate_new_model( - simulated_fitted_pln_2cov_array, + simulated_fitted_model_2cov_array, endog_sim_2cov, exog=exog_sim_2cov, offsets=offsets_sim_2cov, @@ -267,147 +281,154 @@ def simulated_loaded_pln_2cov_array(simulated_fitted_pln_2cov_array): ) -sim_pln_2cov_instance = [ - "simulated_pln_2cov_array", - "simulated_pln_2cov_formula", +sim_model_2cov_instance = [ + "simulated_model_2cov_array", + "simulated_model_2cov_formula", ] -instances = sim_pln_2cov_instance + instances +sim_model_instance = sim_model_0cov_instance + sim_model_2cov_instance dict_fixtures = add_list_of_fixture_to_dict( - dict_fixtures, "sim_pln_2cov_instance", sim_pln_2cov_instance + dict_fixtures, "sim_model_instance", sim_model_instance ) +instances = sim_model_2cov_instance + instances + -sim_pln_2cov_fitted = [ - "simulated_fitted_pln_2cov_array", - "simulated_fitted_pln_2cov_formula", +dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "sim_model_2cov_instance", sim_model_2cov_instance +) +sim_model_2cov_fitted = [ + "simulated_fitted_model_2cov_array", + "simulated_fitted_model_2cov_formula", ] dict_fixtures = add_list_of_fixture_to_dict( - dict_fixtures, "sim_pln_2cov_fitted", sim_pln_2cov_fitted + dict_fixtures, "sim_model_2cov_fitted", sim_model_2cov_fitted ) -sim_pln_2cov_loaded = [ - "simulated_loaded_pln_2cov_array", - "simulated_loaded_pln_2cov_formula", +sim_model_2cov_loaded = [ + "simulated_loaded_model_2cov_array", + "simulated_loaded_model_2cov_formula", ] dict_fixtures = add_list_of_fixture_to_dict( - dict_fixtures, "sim_pln_2cov_loaded", sim_pln_2cov_loaded + dict_fixtures, "sim_model_2cov_loaded", sim_model_2cov_loaded ) -sim_pln_2cov = sim_pln_2cov_instance + sim_pln_2cov_fitted + sim_pln_2cov_loaded -dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "sim_pln_2cov", sim_pln_2cov) +sim_model_2cov = sim_model_2cov_instance + sim_model_2cov_fitted + sim_model_2cov_loaded +dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "sim_model_2cov", sim_model_2cov +) @pytest.fixture(params=params) @cache -def real_pln_intercept_array(request): +def real_model_intercept_array(request): cls = request.param - pln_full = cls(endog_real, add_const=True) - return pln_full + model = cls(endog_real, add_const=True) + return model @pytest.fixture -def real_fitted_pln_intercept_array(real_pln_intercept_array): 
- real_pln_intercept_array.fit() - return real_pln_intercept_array +def real_fitted_model_intercept_array(real_model_intercept_array): + real_model_intercept_array.fit() + return real_model_intercept_array @pytest.fixture(params=params) @cache -def real_pln_intercept_formula(request): +def real_model_intercept_formula(request): cls = request.param - pln_full = cls("endog ~ 1", data_real) - return pln_full + model = cls("endog ~ 1", data_real) + return model @pytest.fixture -def real_fitted_pln_intercept_formula(real_pln_intercept_formula): - real_pln_intercept_formula.fit() - return real_pln_intercept_formula +def real_fitted_model_intercept_formula(real_model_intercept_formula): + real_model_intercept_formula.fit() + return real_model_intercept_formula @pytest.fixture -def real_loaded_pln_intercept_formula(real_fitted_pln_intercept_formula): - real_fitted_pln_intercept_formula.save() +def real_loaded_model_intercept_formula(real_fitted_model_intercept_formula): + real_fitted_model_intercept_formula.save() return generate_new_model( - real_fitted_pln_intercept_formula, "endog ~ 1", data=data_real + real_fitted_model_intercept_formula, "endog ~ 1", data=data_real ) @pytest.fixture -def real_loaded_pln_intercept_array(real_fitted_pln_intercept_array): - real_fitted_pln_intercept_array.save() +def real_loaded_model_intercept_array(real_fitted_model_intercept_array): + real_fitted_model_intercept_array.save() return generate_new_model( - real_fitted_pln_intercept_array, + real_fitted_model_intercept_array, endog_real, add_const=True, ) -real_pln_instance = [ - "real_pln_intercept_array", - "real_pln_intercept_formula", +real_model_instance = [ + "real_model_intercept_array", + "real_model_intercept_formula", ] -instances = real_pln_instance + instances +instances = real_model_instance + instances dict_fixtures = add_list_of_fixture_to_dict( - dict_fixtures, "real_pln_instance", real_pln_instance + dict_fixtures, "real_model_instance", real_model_instance ) -real_pln_fitted = [ - "real_fitted_pln_intercept_array", - "real_fitted_pln_intercept_formula", +real_model_fitted = [ + "real_fitted_model_intercept_array", + "real_fitted_model_intercept_formula", ] dict_fixtures = add_list_of_fixture_to_dict( - dict_fixtures, "real_pln_fitted", real_pln_fitted + dict_fixtures, "real_model_fitted", real_model_fitted ) -real_pln_loaded = [ - "real_loaded_pln_intercept_array", - "real_loaded_pln_intercept_formula", +real_model_loaded = [ + "real_loaded_model_intercept_array", + "real_loaded_model_intercept_formula", ] dict_fixtures = add_list_of_fixture_to_dict( - dict_fixtures, "real_pln_loaded", real_pln_loaded + dict_fixtures, "real_model_loaded", real_model_loaded ) -sim_loaded_pln = sim_pln_0cov_loaded + sim_pln_2cov_loaded +sim_loaded_model = sim_model_0cov_loaded + sim_model_2cov_loaded -loaded_pln = real_pln_loaded + sim_loaded_pln -dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "loaded_pln", loaded_pln) +loaded_model = real_model_loaded + sim_loaded_model +dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "loaded_model", loaded_model) -simulated_pln_fitted = sim_pln_0cov_fitted + sim_pln_2cov_fitted +simulated_model_fitted = sim_model_0cov_fitted + sim_model_2cov_fitted dict_fixtures = add_list_of_fixture_to_dict( - dict_fixtures, "simulated_pln_fitted", simulated_pln_fitted + dict_fixtures, "simulated_model_fitted", simulated_model_fitted ) -fitted_pln = real_pln_fitted + simulated_pln_fitted -dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "fitted_pln", fitted_pln) 
+fitted_model = real_model_fitted + simulated_model_fitted +dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "fitted_model", fitted_model) -loaded_and_fitted_sim_pln = simulated_pln_fitted + sim_loaded_pln -loaded_and_fitted_real_pln = real_pln_fitted + real_pln_loaded +loaded_and_fitted_sim_model = simulated_model_fitted + sim_loaded_model +loaded_and_fitted_real_model = real_model_fitted + real_model_loaded dict_fixtures = add_list_of_fixture_to_dict( - dict_fixtures, "loaded_and_fitted_real_pln", loaded_and_fitted_real_pln + dict_fixtures, "loaded_and_fitted_real_model", loaded_and_fitted_real_model ) dict_fixtures = add_list_of_fixture_to_dict( - dict_fixtures, "loaded_and_fitted_sim_pln", loaded_and_fitted_sim_pln + dict_fixtures, "loaded_and_fitted_sim_model", loaded_and_fitted_sim_model ) -loaded_and_fitted_pln = fitted_pln + loaded_pln +loaded_and_fitted_model = fitted_model + loaded_model dict_fixtures = add_list_of_fixture_to_dict( - dict_fixtures, "loaded_and_fitted_pln", loaded_and_fitted_pln + dict_fixtures, "loaded_and_fitted_model", loaded_and_fitted_model ) -real_pln = real_pln_instance + real_pln_fitted + real_pln_loaded -dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "real_pln", real_pln) +real_model = real_model_instance + real_model_fitted + real_model_loaded +dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "real_model", real_model) -sim_pln = sim_pln_2cov + sim_pln_0cov -dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "sim_pln", sim_pln) +sim_model = sim_model_2cov + sim_model_0cov +dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "sim_model", sim_model) -all_pln = real_pln + sim_pln + instances +all_model = real_model + sim_model + instances dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "instances", instances) -dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "all_pln", all_pln) +dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "all_model", all_model) -for string_fixture in all_pln: +for string_fixture in all_model: print("string_fixture", string_fixture) dict_fixtures = add_fixture_to_dict(dict_fixtures, string_fixture) diff --git a/tests/create_readme_and_docstrings_tests.py b/tests/create_readme_getting_started_and_docstrings_tests.py similarity index 72% rename from tests/create_readme_and_docstrings_tests.py rename to tests/create_readme_getting_started_and_docstrings_tests.py index d9f27aebc018b7b277e3fa04b63891f607cc5fd7..113bf841c8cb6f5723069d8e73639efb876f4cb0 100644 --- a/tests/create_readme_and_docstrings_tests.py +++ b/tests/create_readme_getting_started_and_docstrings_tests.py @@ -4,6 +4,7 @@ import os dir_docstrings = "docstrings_examples" dir_readme = "readme_examples" +dir_getting_started = "getting_started" def get_lines(path_to_file, filename, filetype=".py"): @@ -43,15 +44,15 @@ def get_example_readme(lines): in_example = False elif in_example is True: example.append(line) - example.pop(0) # The first is pip install pyPLNmodels which is not python code. + example.pop() # The last line is pip install pyPLNmodels which is not python code. 
return [example] -def write_examples(examples, filename): +def write_file(examples, filename, string_definer, dir): for i in range(len(examples)): example = examples[i] nb_example = str(i + 1) - example_filename = f"test_{filename}_example_{nb_example}.py" + example_filename = f"{dir}/test_{filename}_{string_definer}_{nb_example}.py" try: os.remove(example_filename) except FileNotFoundError: @@ -64,19 +65,34 @@ def write_examples(examples, filename): def filename_to_docstring_example_file(filename, dirname): lines = get_lines("../pyPLNmodels/", filename) examples = get_examples_docstring(lines) - write_examples(examples, filename) + write_file(examples, filename, "example", dir=dirname) def filename_to_readme_example_file(): lines = get_lines("../", "README", filetype=".md") examples = get_example_readme(lines) - write_examples(examples, "readme") + write_file(examples, "readme", "example", dir=dir_readme) + + +lines_getting_started = get_lines("./", "test_getting_started") +new_lines = [] +for line in lines_getting_started: + if len(line) > 20: + if line[0:11] != "get_ipython": + new_lines.append(line) + else: + new_lines.append(line) os.makedirs(dir_readme, exist_ok=True) +os.makedirs(dir_docstrings, exist_ok=True) +os.makedirs(dir_getting_started, exist_ok=True) + +write_file([new_lines], "getting_started", "", dir_getting_started) + filename_to_readme_example_file() -os.makedirs("docstrings_examples", exist_ok=True) + filename_to_docstring_example_file("_utils", dir_docstrings) filename_to_docstring_example_file("models", dir_docstrings) filename_to_docstring_example_file("elbos", dir_docstrings) diff --git a/tests/import_data.py b/tests/import_data.py index 9ef5ef7e0219ce836a9c8b86dd021c7b68ba436d..9942db40f966bcfd0ed41a3ddfcc131b67b6d5b9 100644 --- a/tests/import_data.py +++ b/tests/import_data.py @@ -1,10 +1,15 @@ import os +import torch from pyPLNmodels import ( get_simulated_count_data, get_real_count_data, ) +if torch.cuda.is_available(): + DEVICE = "cuda:0" +else: + DEVICE = "cpu" ( endog_sim_0cov, diff --git a/tests/run_readme_docstrings_getting_started.sh b/tests/run_readme_docstrings_getting_started.sh new file mode 100755 index 0000000000000000000000000000000000000000..594890339b63c4144b6b9b4d69ce630e98c470d8 --- /dev/null +++ b/tests/run_readme_docstrings_getting_started.sh @@ -0,0 +1,15 @@ +#!/bin/sh +for file in docstrings_examples/* +do + python $file +done + +for file in readme_examples/* +do + python $file +done + +for file in getting_started/* +do + python $file +done diff --git a/tests/test_common.py b/tests/test_common.py index b1a6837cbd21a602496ba1fca6cddd4c506aeab7..bd5ca62cfcc7054651a02015033e38958336f317 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -8,82 +8,90 @@ from tests.utils import MSE, filter_models from tests.import_data import true_sim_0cov, true_sim_2cov, endog_real +pln_and_plnpca = ["Pln", "PlnPCA"] +single_models = ["Pln", "PlnPCA", "ZIPln"] -@pytest.mark.parametrize("any_pln", dict_fixtures["loaded_and_fitted_pln"]) -@filter_models(["Pln", "PlnPCA"]) -def test_properties(any_pln): - assert hasattr(any_pln, "latent_parameters") - assert hasattr(any_pln, "latent_variables") - assert hasattr(any_pln, "optim_parameters") - assert hasattr(any_pln, "model_parameters") +@pytest.mark.parametrize("any_model", dict_fixtures["loaded_and_fitted_model"]) +@filter_models(single_models) +def test_properties(any_model): + assert hasattr(any_model, "latent_parameters") + assert hasattr(any_model, "latent_variables") + assert hasattr(any_model, 
"optim_parameters") + assert hasattr(any_model, "model_parameters") -@pytest.mark.parametrize("sim_pln", dict_fixtures["loaded_and_fitted_pln"]) -@filter_models(["Pln", "PlnPCA"]) -def test_predict_simulated(sim_pln): - if sim_pln.nb_cov == 0: - assert sim_pln.predict() is None + +@pytest.mark.parametrize("sim_model", dict_fixtures["loaded_and_fitted_model"]) +@filter_models(pln_and_plnpca) +def test_predict_simulated(sim_model): + if sim_model.nb_cov == 0: + assert sim_model.predict() is None with pytest.raises(AttributeError): - sim_pln.predict(1) + sim_model.predict(1) else: - X = torch.randn((sim_pln.n_samples, sim_pln.nb_cov)) - prediction = sim_pln.predict(X) - expected = X @ sim_pln.coef + X = torch.randn((sim_model.n_samples, sim_model.nb_cov)) + prediction = sim_model.predict(X) + expected = X @ sim_model.coef assert torch.all(torch.eq(expected, prediction)) -@pytest.mark.parametrize("any_instance_pln", dict_fixtures["instances"]) -def test_verbose(any_instance_pln): - any_instance_pln.fit(verbose=True, tol=0.1) +@pytest.mark.parametrize("any_instance_model", dict_fixtures["instances"]) +def test_verbose(any_instance_model): + any_instance_model.fit(verbose=True, tol=0.1) @pytest.mark.parametrize( - "simulated_fitted_any_pln", dict_fixtures["loaded_and_fitted_sim_pln"] + "simulated_fitted_any_model", dict_fixtures["loaded_and_fitted_sim_model"] ) -@filter_models(["Pln", "PlnPCA"]) -def test_find_right_covariance(simulated_fitted_any_pln): - if simulated_fitted_any_pln.nb_cov == 0: - true_covariance = true_sim_0cov["Sigma"] - elif simulated_fitted_any_pln.nb_cov == 2: - true_covariance = true_sim_2cov["Sigma"] +@filter_models(pln_and_plnpca) +def test_find_right_covariance(simulated_fitted_any_model): + if simulated_fitted_any_model.nb_cov == 0: + true_covariance = true_sim_0cov["Sigma"].cpu() + elif simulated_fitted_any_model.nb_cov == 2: + true_covariance = true_sim_2cov["Sigma"].cpu() else: raise ValueError( - f"Not the right numbers of covariance({simulated_fitted_any_pln.nb_cov})" + f"Not the right numbers of covariance({simulated_fitted_any_model.nb_cov})" ) - mse_covariance = MSE(simulated_fitted_any_pln.covariance - true_covariance) + mse_covariance = MSE( + simulated_fitted_any_model.covariance.cpu() - true_covariance.cpu() + ) assert mse_covariance < 0.05 @pytest.mark.parametrize( - "real_fitted_and_loaded_pln", dict_fixtures["loaded_and_fitted_real_pln"] + "real_fitted_and_loaded_model", dict_fixtures["loaded_and_fitted_real_model"] ) -@filter_models(["Pln", "PlnPCA"]) -def test_right_covariance_shape(real_fitted_and_loaded_pln): - assert real_fitted_and_loaded_pln.covariance.shape == ( +@filter_models(single_models) +def test_right_covariance_shape(real_fitted_and_loaded_model): + assert real_fitted_and_loaded_model.covariance.shape == ( endog_real.shape[1], endog_real.shape[1], ) @pytest.mark.parametrize( - "simulated_fitted_any_pln", dict_fixtures["loaded_and_fitted_pln"] + "simulated_fitted_any_model", dict_fixtures["loaded_and_fitted_model"] ) -@filter_models(["Pln", "PlnPCA"]) -def test_find_right_coef(simulated_fitted_any_pln): - if simulated_fitted_any_pln.nb_cov == 2: +@filter_models(pln_and_plnpca) +def test_find_right_coef(simulated_fitted_any_model): + if simulated_fitted_any_model.nb_cov == 2: true_coef = true_sim_2cov["beta"] - mse_coef = MSE(simulated_fitted_any_pln.coef - true_coef) + mse_coef = MSE(simulated_fitted_any_model.coef.cpu() - true_coef.cpu()) assert mse_coef < 0.1 - elif simulated_fitted_any_pln.nb_cov == 0: - assert 
simulated_fitted_any_pln.coef is None + elif simulated_fitted_any_model.nb_cov == 0: + assert simulated_fitted_any_model.coef is None -@pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_pln"]) -@filter_models(["Pln", "PlnPCA"]) -def test_fail_count_setter(pln): +@pytest.mark.parametrize("model", dict_fixtures["loaded_and_fitted_model"]) +@filter_models(single_models) +def test_fail_count_setter(model): wrong_endog = torch.randint(size=(10, 5), low=0, high=10) - with pytest.raises(Exception): - pln.endog = wrong_endog + negative_endog = -model._endog + with pytest.raises(ValueError): + model.endog = wrong_endog + with pytest.raises(ValueError): + model.endog = negative_endog @pytest.mark.parametrize("instance", dict_fixtures["instances"]) @@ -96,9 +104,23 @@ def test__print_end_of_fitting_message(instance): instance.fit(nb_max_iteration=4) -@pytest.mark.parametrize("pln", dict_fixtures["fitted_pln"]) -@filter_models(["Pln", "PlnPCA"]) -def test_fail_wrong_exog_prediction(pln): - X = torch.randn(pln.n_samples, pln.nb_cov + 1) +@pytest.mark.parametrize("model", dict_fixtures["fitted_model"]) +@filter_models(single_models) +def test_fail_wrong_exog_prediction(model): + X = torch.randn(model.n_samples, model.nb_cov + 1) with pytest.raises(Exception): - pln.predict(X) + model.predict(X) + + +@pytest.mark.parametrize("model", dict_fixtures["sim_model_instance"]) +@filter_models(pln_and_plnpca) +def test_batch(model): + model.fit(batch_size=20) + print(model) + model.show() + if model.nb_cov == 2: + true_coef = true_sim_2cov["beta"] + mse_coef = MSE(model.coef.cpu() - true_coef.cpu()) + assert mse_coef < 0.1 + elif model.nb_cov == 0: + assert model.coef is None diff --git a/tests/test_getting_started.py b/tests/test_getting_started.py deleted file mode 100644 index 69299741896710896384734d31e49a9e2e786ebf..0000000000000000000000000000000000000000 --- a/tests/test_getting_started.py +++ /dev/null @@ -1,141 +0,0 @@ -#!/usr/bin/env python -# coding: utf-8 - -# get_ipython().system('pip install pyPLNmodels') - - -# ## pyPLNmodels - -# We assume the data comes from a PLN model: $ \text{counts} \sim \mathcal P(\exp(\text{Z}))$, where $Z$ are some unknown latent variables. -# -# -# The goal of the package is to retrieve the latent variables $Z$ given the counts. To do so, one can instantiate a Pln or PlnPCA model, fit it and then extract the latent variables. 
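Spelled out as a hierarchy, the model that comment refers to can be written as follows (a sketch of the usual PLN parametrization; the mean term $x_i^\top B$ only applies when exogenous covariates are supplied, which is how the tests above use `coef`):

$$
Z_i \sim \mathcal N\!\left(x_i^\top B,\ \Sigma\right), \qquad
\text{counts}_{ij} \mid Z_{ij} \sim \mathcal P\!\left(\exp(Z_{ij})\right),
$$

so fitting a `Pln` or `PlnPCA` model amounts to estimating $B$, $\Sigma$ and the variational parameters of $Z$ (the `latent_mean` and `latent_sqrt_var` exercised in the setter tests), after which `latent_variables` returns the estimated $Z$.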
- -# ### Import the needed functions - -from pyPLNmodels import ( - get_real_count_data, - get_simulated_count_data, - load_model, - Pln, - PlnPCA, - PlnPCAcollection, -) -import matplotlib.pyplot as plt - - -# ### Load the data - -counts, labels = get_real_count_data(return_labels=True) # np.ndarray - - -# ### PLN model - -pln = Pln(counts, add_const=True) -pln.fit() - - -print(pln) - - -# #### Once fitted, we can extract multiple variables: - -gaussian = pln.latent_variables -print(gaussian.shape) - - -model_param = pln.model_parameters -print(model_param["coef"].shape) -print(model_param["covariance"].shape) - - -# ### PlnPCA model - -pca = PlnPCA(counts, add_const=True, rank=5) -pca.fit() - - -print(pca) - - -print(pca.latent_variables.shape) - - -print(pca.model_parameters["components"].shape) -print(pca.model_parameters["coef"].shape) - - -# ### One can save the model in order to load it back after: - -pca.save() -dict_init = load_model("PlnPCA_nbcov_1_dim_200_rank_5") -loaded_pca = PlnPCA(counts, add_const=True, dict_initialization=dict_init) -print(loaded_pca) - - -# ### One can fit multiple PCA and choose the best rank with BIC or AIC criterion - -pca_col = PlnPCAcollection(counts, add_const=True, ranks=[5, 15, 25, 40, 50]) -pca_col.fit() - - -pca_col.show() - - -print(pca_col) - - -# ### One can extract the best model found (according to AIC or BIC criterion). - -# #### AIC best model - -print(pca_col.best_model(criterion="AIC")) - - -# #### BIC best model - -print(pca_col.best_model(criterion="BIC")) - - -# #### Visualization of the individuals (sites) with PCA on the latent variables. - -pln.viz(colors=labels) -plt.show() - - -best_pca = pca_col.best_model() -best_pca.viz(colors=labels) -plt.show() - - -# ### What would give a PCA on the log normalize data ? 
- -from sklearn.decomposition import PCA -import numpy as np -import seaborn as sns - - -sk_pca = PCA(n_components=2) -pca_log_counts = sk_pca.fit_transform(np.log(counts + (counts == 0))) -sns.scatterplot(x=pca_log_counts[:, 0], y=pca_log_counts[:, 1], hue=labels) - - -# ### Visualization of the variables - -pln.plot_pca_correlation_graph(["var_1", "var_2"], indices_of_variables=[0, 1]) -plt.show() - - -best_pca.plot_pca_correlation_graph(["var_1", "var_2"], indices_of_variables=[0, 1]) -plt.show() - - -# ### Visualization of each components of the PCA -# - -pln.scatter_pca_matrix(color=labels, n_components=5) -plt.show() - - -best_pca.scatter_pca_matrix(color=labels, n_components=6) -plt.show() diff --git a/tests/test_pln_full.py b/tests/test_pln_full.py index 2d61befd5e2e529c4ec9022910d403c9e301cd32..e5959b0eeadd0f2e3180577d6d4caf073ca4b511 100644 --- a/tests/test_pln_full.py +++ b/tests/test_pln_full.py @@ -4,14 +4,14 @@ from tests.conftest import dict_fixtures from tests.utils import filter_models -@pytest.mark.parametrize("fitted_pln", dict_fixtures["fitted_pln"]) +@pytest.mark.parametrize("fitted_pln", dict_fixtures["fitted_model"]) @filter_models(["Pln"]) def test_number_of_iterations_pln_full(fitted_pln): nb_iterations = len(fitted_pln._elbos_list) - assert 50 < nb_iterations < 500 + assert 20 < nb_iterations < 2000 -@pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_pln"]) +@pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_model"]) @filter_models(["Pln"]) -def test_latent_var_full(pln): +def test_latent_variables(pln): assert pln.transform().shape == pln.endog.shape diff --git a/tests/test_plnpcacollection.py b/tests/test_plnpcacollection.py index 6634f2d2eba5f61520d7a324ea08026233bcc9d4..761afabc3f27d7d4b601ff53abc9185b112bd6cc 100644 --- a/tests/test_plnpcacollection.py +++ b/tests/test_plnpcacollection.py @@ -6,16 +6,17 @@ import numpy as np from tests.conftest import dict_fixtures from tests.utils import MSE, filter_models +from tests.import_data import true_sim_0cov, true_sim_2cov -@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_pln"]) +@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_model"]) @filter_models(["PlnPCAcollection"]) def test_best_model(plnpca): best_model = plnpca.best_model() print(best_model) -@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_pln"]) +@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_model"]) @filter_models(["PlnPCAcollection"]) def test_projected_variables(plnpca): best_model = plnpca.best_model() @@ -23,21 +24,20 @@ def test_projected_variables(plnpca): assert plv.shape[0] == best_model.n_samples and plv.shape[1] == best_model.rank -@pytest.mark.parametrize("fitted_pln", dict_fixtures["fitted_pln"]) -@filter_models(["PlnPCA"]) -def test_number_of_iterations_plnpca(fitted_pln): - nb_iterations = len(fitted_pln._elbos_list) - assert 100 < nb_iterations < 5000 +@pytest.mark.parametrize("plnpca", dict_fixtures["sim_model_instance"]) +@filter_models(["PlnPCAcollection"]) +def test_right_nbcov(plnpca): + assert plnpca.nb_cov == 0 or plnpca.nb_cov == 2 -@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_pln"]) +@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_model"]) @filter_models(["PlnPCA"]) def test_latent_var_pca(plnpca): - assert plnpca.transform(project=False).shape == plnpca.endog.shape - assert plnpca.transform().shape == (plnpca.n_samples, plnpca.rank) + assert plnpca.transform().shape == 
plnpca.endog.shape + assert plnpca.transform(project=True).shape == (plnpca.n_samples, plnpca.rank) -@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_pln"]) +@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_model"]) @filter_models(["PlnPCAcollection"]) def test_additional_methods_pca(plnpca): plnpca.show() @@ -46,14 +46,14 @@ def test_additional_methods_pca(plnpca): plnpca.loglikes -@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_pln"]) +@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_model"]) @filter_models(["PlnPCAcollection"]) def test_wrong_criterion(plnpca): with pytest.raises(ValueError): plnpca.best_model("AIK") -@pytest.mark.parametrize("collection", dict_fixtures["loaded_and_fitted_pln"]) +@pytest.mark.parametrize("collection", dict_fixtures["loaded_and_fitted_model"]) @filter_models(["PlnPCAcollection"]) def test_item(collection): print(collection[collection.ranks[0]]) @@ -62,3 +62,28 @@ def test_item(collection): assert collection.ranks[0] in collection assert collection.ranks[0] in list(collection.keys()) collection.get(collection.ranks[0], None) + + +@pytest.mark.parametrize("collection", dict_fixtures["sim_model_instance"]) +@filter_models(["PlnPCAcollection"]) +def test_batch(collection): + collection.fit(batch_size=20) + assert collection.nb_cov == 0 or collection.nb_cov == 2 + if collection.nb_cov == 0: + true_covariance = true_sim_0cov["Sigma"] + for model in collection.values(): + assert model.coef is None + true_coef = None + elif collection.nb_cov == 2: + true_covariance = true_sim_2cov["Sigma"] + true_coef = true_sim_2cov["beta"] + else: + raise ValueError(f"Not the right numbers of covariance({collection.nb_cov})") + for model in collection.values(): + mse_covariance = MSE(model.covariance.cpu() - true_covariance.cpu()) + if true_coef is not None: + mse_coef = MSE(model.coef.cpu() - true_coef.cpu()) + assert mse_coef < 0.35 + assert mse_covariance < 0.25 + collection.fit() + assert collection.batch_size == 20 diff --git a/tests/test_setters.py b/tests/test_setters.py index 828989e81fb3483cfa3d1963af8c9d1ff1e0e2e3..f230d858db207764974518a17e45f40f8cc5961f 100644 --- a/tests/test_setters.py +++ b/tests/test_setters.py @@ -5,144 +5,169 @@ import torch from tests.conftest import dict_fixtures from tests.utils import MSE, filter_models - -@pytest.mark.parametrize("pln", dict_fixtures["all_pln"]) -def test_data_setter_with_torch(pln): - pln.endog = pln.endog - pln.exog = pln.exog - pln.offsets = pln.offsets - pln.fit() - - -@pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_pln"]) -@filter_models(["Pln", "PlnPCA"]) -def test_parameters_setter_with_torch(pln): - pln.latent_mean = pln.latent_mean - pln.latent_sqrt_var = pln.latent_sqrt_var - pln.coef = pln.coef - if pln._NAME == "PlnPCA": - pln.components = pln.components - pln.fit() - - -@pytest.mark.parametrize("pln", dict_fixtures["all_pln"]) -def test_data_setter_with_numpy(pln): - np_endog = pln.endog.numpy() - if pln.exog is not None: - np_exog = pln.exog.numpy() +single_models = ["Pln", "PlnPCA", "ZIPln"] + + +@pytest.mark.parametrize("model", dict_fixtures["loaded_model"]) +def test_data_setter_with_torch(model): + model.endog = model.endog + model.exog = model.exog + model.offsets = model.offsets + model.fit() + + +@pytest.mark.parametrize("model", dict_fixtures["loaded_and_fitted_model"]) +@filter_models(single_models) +def test_parameters_setter_with_torch(model): + model.latent_mean = model.latent_mean + 
model.latent_sqrt_var = model.latent_sqrt_var + if model._NAME != "Pln": + model.coef = model.coef + if model._NAME == "PlnPCA" or model._NAME == "ZIPln": + model.components = model.components + if model._NAME == "ZIPln": + model.coef_inflation = model.coef_inflation + model.fit() + + +@pytest.mark.parametrize("model", dict_fixtures["all_model"]) +def test_data_setter_with_numpy(model): + np_endog = model.endog.numpy() + if model.exog is not None: + np_exog = model.exog.numpy() else: np_exog = None - np_offsets = pln.offsets.numpy() - pln.endog = np_endog - pln.exog = np_exog - pln.offsets = np_offsets - pln.fit() - - -@pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_pln"]) -@filter_models(["Pln", "PlnPCA"]) -def test_parameters_setter_with_numpy(pln): - np_latent_mean = pln.latent_mean.numpy() - np_latent_sqrt_var = pln.latent_sqrt_var.numpy() - if pln.coef is not None: - np_coef = pln.coef.numpy() + np_offsets = model.offsets.numpy() + model.endog = np_endog + model.exog = np_exog + model.offsets = np_offsets + model.fit() + + +@pytest.mark.parametrize("model", dict_fixtures["loaded_and_fitted_model"]) +@filter_models(single_models) +def test_parameters_setter_with_numpy(model): + np_latent_mean = model.latent_mean.numpy() + np_latent_sqrt_var = model.latent_sqrt_var.numpy() + if model.coef is not None: + np_coef = model.coef.numpy() else: np_coef = None - pln.latent_mean = np_latent_mean - pln.latent_sqrt_var = np_latent_sqrt_var - pln.coef = np_coef - if pln._NAME == "PlnPCA": - pln.components = pln.components.numpy() - pln.fit() - - -@pytest.mark.parametrize("pln", dict_fixtures["all_pln"]) -def test_data_setter_with_pandas(pln): - pd_endog = pd.DataFrame(pln.endog.numpy()) - if pln.exog is not None: - pd_exog = pd.DataFrame(pln.exog.numpy()) + model.latent_mean = np_latent_mean + model.latent_sqrt_var = np_latent_sqrt_var + if model._NAME != "Pln": + model.coef = np_coef + if model._NAME == "PlnPCA" or model._NAME == "ZIPln": + model.components = model.components.numpy() + if model._NAME == "ZIPln": + model.coef_inflation = model.coef_inflation.numpy() + model.fit() + + +@pytest.mark.parametrize("model", dict_fixtures["all_model"]) +def test_batch_size_setter(model): + model.batch_size = 20 + model.fit(nb_max_iteration=3) + assert model.batch_size == 20 + + +@pytest.mark.parametrize("model", dict_fixtures["all_model"]) +def test_fail_batch_size_setter(model): + with pytest.raises(ValueError): + model.batch_size = model.n_samples + 1 + + +@pytest.mark.parametrize("model", dict_fixtures["all_model"]) +def test_data_setter_with_pandas(model): + pd_endog = pd.DataFrame(model.endog.numpy()) + if model.exog is not None: + pd_exog = pd.DataFrame(model.exog.numpy()) else: pd_exog = None - pd_offsets = pd.DataFrame(pln.offsets.numpy()) - pln.endog = pd_endog - pln.exog = pd_exog - pln.offsets = pd_offsets - pln.fit() - - -@pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_pln"]) -@filter_models(["Pln", "PlnPCA"]) -def test_parameters_setter_with_pandas(pln): - pd_latent_mean = pd.DataFrame(pln.latent_mean.numpy()) - pd_latent_sqrt_var = pd.DataFrame(pln.latent_sqrt_var.numpy()) - if pln.coef is not None: - pd_coef = pd.DataFrame(pln.coef.numpy()) + pd_offsets = pd.DataFrame(model.offsets.numpy()) + model.endog = pd_endog + model.exog = pd_exog + model.offsets = pd_offsets + model.fit() + + +@pytest.mark.parametrize("model", dict_fixtures["loaded_and_fitted_model"]) +@filter_models(single_models) +def test_parameters_setter_with_pandas(model): + pd_latent_mean = 
pd.DataFrame(model.latent_mean.numpy()) + pd_latent_sqrt_var = pd.DataFrame(model.latent_sqrt_var.numpy()) + if model.coef is not None: + pd_coef = pd.DataFrame(model.coef.numpy()) else: pd_coef = None - pln.latent_mean = pd_latent_mean - pln.latent_sqrt_var = pd_latent_sqrt_var - pln.coef = pd_coef - if pln._NAME == "PlnPCA": - pln.components = pd.DataFrame(pln.components.numpy()) - pln.fit() - - -@pytest.mark.parametrize("pln", dict_fixtures["all_pln"]) -def test_fail_data_setter_with_torch(pln): + model.latent_mean = pd_latent_mean + model.latent_sqrt_var = pd_latent_sqrt_var + if model._NAME != "Pln": + model.coef = pd_coef + if model._NAME == "PlnPCA": + model.components = pd.DataFrame(model.components.numpy()) + if model._NAME == "ZIPln": + model.coef_inflation = pd.DataFrame(model.coef_inflation.numpy()) + model.fit() + + +@pytest.mark.parametrize("model", dict_fixtures["all_model"]) +def test_fail_data_setter_with_torch(model): with pytest.raises(ValueError): - pln.endog = pln.endog - 100 + model.endog = -model.endog - n, p = pln.endog.shape - if pln.exog is None: + n, p = model.endog.shape + if model.exog is None: d = 0 else: - d = pln.exog.shape[-1] + d = model.exog.shape[-1] with pytest.raises(ValueError): - pln.endog = torch.zeros(n + 1, p) + model.endog = torch.zeros(n + 1, p) with pytest.raises(ValueError): - pln.endog = torch.zeros(n, p + 1) + model.endog = torch.zeros(n, p + 1) with pytest.raises(ValueError): - pln.exog = torch.zeros(n + 1, d) + model.exog = torch.zeros(n + 1, d) with pytest.raises(ValueError): - pln.offsets = torch.zeros(n + 1, p) + model.offsets = torch.zeros(n + 1, p) with pytest.raises(ValueError): - pln.offsets = torch.zeros(n, p + 1) + model.offsets = torch.zeros(n, p + 1) -@pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_pln"]) -@filter_models(["Pln", "PlnPCA"]) -def test_fail_parameters_setter_with_torch(pln): - n, dim_latent = pln.latent_mean.shape - dim = pln.endog.shape[1] +@pytest.mark.parametrize("model", dict_fixtures["loaded_and_fitted_model"]) +@filter_models(single_models) +def test_fail_parameters_setter_with_torch(model): + n, dim_latent = model.latent_mean.shape + dim = model.endog.shape[1] with pytest.raises(ValueError): - pln.latent_mean = torch.zeros(n + 1, dim_latent) + model.latent_mean = torch.zeros(n + 1, dim_latent) with pytest.raises(ValueError): - pln.latent_mean = torch.zeros(n, dim_latent + 1) + model.latent_mean = torch.zeros(n, dim_latent + 1) with pytest.raises(ValueError): - pln.latent_sqrt_var = torch.zeros(n + 1, dim_latent) + model.latent_sqrt_var = torch.zeros(n + 1, dim_latent) with pytest.raises(ValueError): - pln.latent_sqrt_var = torch.zeros(n, dim_latent + 1) + model.latent_sqrt_var = torch.zeros(n, dim_latent + 1) - if pln._NAME == "PlnPCA": + if model._NAME == "PlnPCA": with pytest.raises(ValueError): - pln.components = torch.zeros(dim, dim_latent + 1) + model.components = torch.zeros(dim, dim_latent + 1) with pytest.raises(ValueError): - pln.components = torch.zeros(dim + 1, dim_latent) + model.components = torch.zeros(dim + 1, dim_latent) - if pln.exog is None: + if model.exog is None: d = 0 else: - d = pln.exog.shape[-1] - with pytest.raises(ValueError): - pln.coef = torch.zeros(d + 1, dim) + d = model.exog.shape[-1] + if model._NAME != "Pln": + with pytest.raises(ValueError): + model.coef = torch.zeros(d + 1, dim) - with pytest.raises(ValueError): - pln.coef = torch.zeros(d, dim + 1) + with pytest.raises(ValueError): + model.coef = torch.zeros(d, dim + 1) diff --git a/tests/test_viz.py 
b/tests/test_viz.py index be24fcf141426851bda58431c071d1c1f9746ab0..d4f9a7383547a3cc14b4db1ebca694840e9ae91a 100644 --- a/tests/test_viz.py +++ b/tests/test_viz.py @@ -7,47 +7,49 @@ from tests.utils import MSE, filter_models from tests.import_data import true_sim_0cov, true_sim_2cov, labels_real +single_models = ["Pln", "PlnPCA", "ZIPln"] -@pytest.mark.parametrize("any_pln", dict_fixtures["loaded_and_fitted_pln"]) -def test_print(any_pln): - print(any_pln) - - -@pytest.mark.parametrize("any_pln", dict_fixtures["fitted_pln"]) -@filter_models(["Pln", "PlnPCA"]) -def test_show_coef_transform_covariance_pcaprojected(any_pln): - any_pln.show() - any_pln._plotargs._show_loss() - any_pln._plotargs._show_stopping_criterion() - assert hasattr(any_pln, "coef") - assert callable(any_pln.transform) - assert hasattr(any_pln, "covariance") - assert callable(any_pln.sk_PCA) - assert any_pln.sk_PCA(n_components=None) is not None + +@pytest.mark.parametrize("any_model", dict_fixtures["loaded_and_fitted_model"]) +def test_print(any_model): + print(any_model) + + +@pytest.mark.parametrize("any_model", dict_fixtures["fitted_model"]) +@filter_models(single_models) +def test_show_coef_transform_covariance_pcaprojected(any_model): + any_model.show() + any_model._criterion_args._show_loss() + any_model._criterion_args._show_stopping_criterion() + assert hasattr(any_model, "coef") + assert callable(any_model.transform) + assert hasattr(any_model, "covariance") + assert callable(any_model.sk_PCA) + assert any_model.sk_PCA(n_components=None) is not None with pytest.raises(Exception): - any_pln.sk_PCA(n_components=any_pln.dim + 1) + any_model.sk_PCA(n_components=any_model.dim + 1) -@pytest.mark.parametrize("pln", dict_fixtures["fitted_pln"]) -@filter_models(["Pln"]) -def test_scatter_pca_matrix_pln(pln): - pln.scatter_pca_matrix(n_components=8) +@pytest.mark.parametrize("model", dict_fixtures["fitted_model"]) +@filter_models(["Pln", "ZIPln"]) +def test_scatter_pca_matrix_pln(model): + model.scatter_pca_matrix(n_components=8) -@pytest.mark.parametrize("pln", dict_fixtures["fitted_pln"]) +@pytest.mark.parametrize("model", dict_fixtures["fitted_model"]) @filter_models(["PlnPCA"]) -def test_scatter_pca_matrix_plnpca(pln): - pln.scatter_pca_matrix(n_components=2) - pln.scatter_pca_matrix() +def test_scatter_pca_matrix_plnpca(model): + model.scatter_pca_matrix(n_components=2) + model.scatter_pca_matrix() -@pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_real_pln"]) -@filter_models(["Pln", "PlnPCA"]) -def test_label_scatter_pca_matrix(pln): - pln.scatter_pca_matrix(n_components=4, color=labels_real) +@pytest.mark.parametrize("model", dict_fixtures["loaded_and_fitted_real_model"]) +@filter_models(single_models) +def test_label_scatter_pca_matrix(model): + model.scatter_pca_matrix(n_components=4, color=labels_real) -@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_pln"]) +@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_model"]) @filter_models(["PlnPCAcollection"]) def test_viz_pcacol(plnpca): for model in plnpca.values(): @@ -64,38 +66,38 @@ def test_viz_pcacol(plnpca): plt.show() -@pytest.mark.parametrize("pln", dict_fixtures["real_fitted_pln_intercept_array"]) -@filter_models(["Pln", "PlnPCA"]) -def test_plot_pca_correlation_graph_with_names_only(pln): - pln.plot_pca_correlation_graph([f"var_{i}" for i in range(8)]) +@pytest.mark.parametrize("model", dict_fixtures["real_fitted_model_intercept_array"]) +@filter_models(single_models) +def 
test_plot_pca_correlation_graph_with_names_only(model): + model.plot_pca_correlation_graph([f"var_{i}" for i in range(8)]) -@pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_sim_pln"]) -@filter_models(["Pln", "PlnPCA"]) -def test_fail_plot_pca_correlation_graph_without_names(pln): +@pytest.mark.parametrize("model", dict_fixtures["loaded_and_fitted_sim_model"]) +@filter_models(single_models) +def test_fail_plot_pca_correlation_graph_without_names(model): with pytest.raises(ValueError): - pln.plot_pca_correlation_graph([f"var_{i}" for i in range(8)]) + model.plot_pca_correlation_graph([f"var_{i}" for i in range(8)]) with pytest.raises(ValueError): - pln.plot_pca_correlation_graph([f"var_{i}" for i in range(6)], [1, 2, 3]) + model.plot_pca_correlation_graph([f"var_{i}" for i in range(6)], [1, 2, 3]) -@pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_sim_pln"]) -@filter_models(["Pln", "PlnPCA"]) -def test_plot_pca_correlation_graph_without_names(pln): - pln.plot_pca_correlation_graph([f"var_{i}" for i in range(3)], [0, 1, 2]) +@pytest.mark.parametrize("model", dict_fixtures["loaded_and_fitted_sim_model"]) +@filter_models(single_models) +def test_plot_pca_correlation_graph_without_names(model): + model.plot_pca_correlation_graph([f"var_{i}" for i in range(3)], [0, 1, 2]) -@pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_pln"]) -@filter_models(["Pln", "PlnPCA"]) -def test_expected_vs_true(pln): - pln.plot_expected_vs_true() +@pytest.mark.parametrize("model", dict_fixtures["loaded_and_fitted_model"]) +@filter_models(single_models) +def test_expected_vs_true(model): + model.plot_expected_vs_true() fig, ax = plt.subplots() - pln.plot_expected_vs_true(ax=ax) + model.plot_expected_vs_true(ax=ax) -@pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_real_pln"]) -@filter_models(["Pln", "PlnPCA"]) -def test_expected_vs_true_labels(pln): - pln.plot_expected_vs_true(colors=labels_real) +@pytest.mark.parametrize("model", dict_fixtures["loaded_and_fitted_real_model"]) +@filter_models(single_models) +def test_expected_vs_true_labels(model): + model.plot_expected_vs_true(colors=labels_real) fig, ax = plt.subplots() - pln.plot_expected_vs_true(ax=ax, colors=labels_real) + model.plot_expected_vs_true(ax=ax, colors=labels_real) diff --git a/tests/test_zi.py b/tests/test_zi.py new file mode 100644 index 0000000000000000000000000000000000000000..acfaa5bd256e7b0cb55db0f407bc045c0b9fbbff --- /dev/null +++ b/tests/test_zi.py @@ -0,0 +1,129 @@ +import pytest +import torch + +from pyPLNmodels import get_simulation_parameters, sample_pln, ZIPln +from tests.conftest import dict_fixtures +from tests.utils import filter_models, MSE + + +from pyPLNmodels import get_simulated_count_data + + +@pytest.mark.parametrize("zi", dict_fixtures["loaded_and_fitted_model"]) +@filter_models(["ZIPln"]) +def test_properties(zi): + assert hasattr(zi, "latent_prob") + assert hasattr(zi, "coef_inflation") + + +@pytest.mark.parametrize("model", dict_fixtures["loaded_and_fitted_model"]) +@filter_models(["ZIPln"]) +def test_predict(model): + X = torch.randn((model.n_samples, model.nb_cov)) + prediction = model.predict(X) + expected = X @ model.coef + assert torch.all(torch.eq(expected, prediction)) + + +@pytest.mark.parametrize("model", dict_fixtures["loaded_and_fitted_model"]) +@filter_models(["ZIPln"]) +def test_predict_prob(model): + X = torch.randn((model.n_samples, model.nb_cov)) + prediction = model.predict_prob_inflation(X) + expected = torch.sigmoid(X @ model.coef_inflation) + 
assert torch.all(torch.eq(expected, prediction)) + + +@pytest.mark.parametrize("model", dict_fixtures["loaded_and_fitted_model"]) +@filter_models(["ZIPln"]) +def test_fail_predict_prob(model): + X1 = torch.randn((model.n_samples, model.nb_cov + 1)) + X2 = torch.randn((model.n_samples, model.nb_cov - 1)) + with pytest.raises(RuntimeError): + model.predict_prob_inflation(X1) + with pytest.raises(RuntimeError): + model.predict_prob_inflation(X2) + + +@pytest.mark.parametrize("model", dict_fixtures["loaded_and_fitted_model"]) +@filter_models(["ZIPln"]) +def test_fail_predict(model): + X1 = torch.randn((model.n_samples, model.nb_cov + 1)) + X2 = torch.randn((model.n_samples, model.nb_cov - 1)) + with pytest.raises(RuntimeError): + model.predict(X1) + with pytest.raises(RuntimeError): + model.predict(X2) + + +@pytest.mark.parametrize("model", dict_fixtures["sim_model_0cov_fitted_and_loaded"]) +@filter_models(["ZIPln"]) +def test_no_exog_not_possible(model): + assert model.nb_cov == 1 + assert model._coef_inflation.shape[0] == 1 + + +def test_find_right_covariance_coef_and_infla(): + pln_param = get_simulation_parameters(zero_inflated=True, n_samples=1000) + # pln_param._coef += 5 + endog = sample_pln(pln_param, seed=0, return_latent=False) + exog = pln_param.exog + offsets = pln_param.offsets + covariance = pln_param.covariance + coef = pln_param.coef + coef_inflation = pln_param.coef_inflation + endog, exog, offsets, covariance, coef, coef_inflation = get_simulated_count_data( + zero_inflated=True, return_true_param=True, n_samples=1000 + ) + zi = ZIPln(endog, exog=exog, offsets=offsets, use_closed_form_prob=False) + zi.fit() + mse_covariance = MSE(zi.covariance.cpu() - covariance.cpu()) + mse_coef = MSE(zi.coef.cpu() - coef.cpu()) + mse_coef_infla = MSE(zi.coef_inflation.cpu() - coef_inflation.cpu()) + assert mse_coef < 3 + assert mse_coef_infla < 3 + assert mse_covariance < 1 + + +@pytest.mark.parametrize("zi", dict_fixtures["loaded_and_fitted_model"]) +@filter_models(["ZIPln"]) +def test_latent_variables(zi): + z, w = zi.latent_variables + assert z.shape == zi.endog.shape + assert w.shape == zi.endog.shape + + +@pytest.mark.parametrize("zi", dict_fixtures["loaded_and_fitted_model"]) +@filter_models(["ZIPln"]) +def test_transform(zi): + z = zi.transform() + assert z.shape == zi.endog.shape + z, w = zi.transform(return_latent_prob=True) + assert z.shape == w.shape == zi.endog.shape + + +@pytest.mark.parametrize("model", dict_fixtures["sim_model_instance"]) +@filter_models(["ZIPln"]) +def test_batch(model): + pln_param = get_simulation_parameters(zero_inflated=True, n_samples=1000) + # pln_param._coef += 5 + endog = sample_pln(pln_param, seed=0, return_latent=False) + exog = pln_param.exog + offsets = pln_param.offsets + covariance = pln_param.covariance + coef = pln_param.coef + coef_inflation = pln_param.coef_inflation + endog, exog, offsets, covariance, coef, coef_inflation = get_simulated_count_data( + zero_inflated=True, return_true_param=True, n_samples=1000 + ) + zi = ZIPln(endog, exog=exog, offsets=offsets, use_closed_form_prob=False) + zi.fit(batch_size=20) + mse_covariance = MSE(zi.covariance.cpu() - covariance.cpu()) + mse_coef = MSE(zi.coef.cpu() - coef.cpu()) + mse_coef_infla = MSE(zi.coef_inflation.cpu() - coef_inflation.cpu()) + assert mse_coef < 3 + assert mse_coef_infla < 3 + assert mse_covariance < 1 + zi.show() + print(zi) + zi.fit()
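Taken together, the new `tests/test_zi.py` exercises the zero-inflated model end to end. Condensed into a short usage sketch (every name below is taken from the calls in those tests, so it mirrors what the suite does rather than documenting a stable API):

```py
# Condensed from tests/test_zi.py: simulate zero-inflated counts, fit ZIPln,
# and inspect the two latent layers (Gaussian positions and zero-inflation
# probabilities).
from pyPLNmodels import ZIPln, get_simulated_count_data

endog, exog, offsets, covariance, coef, coef_inflation = get_simulated_count_data(
    zero_inflated=True, return_true_param=True, n_samples=1000
)

zi = ZIPln(endog, exog=exog, offsets=offsets, use_closed_form_prob=False)
zi.fit()

gaussian, proba = zi.transform(return_latent_prob=True)  # both are (n_samples, dim)
print(gaussian.shape, proba.shape)
print(zi.coef_inflation.shape)  # logistic coefficients of the zero-inflation part
```

As `test_predict_prob` checks, the inflation probabilities for new covariates `X` are `torch.sigmoid(X @ zi.coef_inflation)`, while `predict` keeps the plain `X @ zi.coef` behaviour of the other models.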