I need to differentiate between 4 different sets of probability distributions from the scipy.stats module: generic univariate, frozen univariate, generic multivariate, & frozen multivariate. Throughout the application, I would like to add type hints for these 4 sets.
For the univariate cases, mypy has no problems with type hints like in this MWE:
```python
from typing import Any

from scipy.stats import norm
from scipy.stats._distn_infrastructure import rv_generic, rv_frozen


def sample_frozen_univariate(n_samples: int, dist: rv_frozen):
    return dist.rvs(n_samples)


def sample_generic_univariate(n_samples: int, dist: rv_generic,
                              *distparams: Any):
    return dist(*distparams).rvs(n_samples)


n_samples = 4
frozen_dist = norm()
print(sample_frozen_univariate(n_samples, frozen_dist))

n_samples = 4
generic_dist = norm
loc, scale = 1, 2
print(sample_generic_univariate(n_samples, generic_dist, loc, scale))
```
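As a sketch of why the univariate hints type-check, a runtime look at the same private classes shows that the base classes themselves define the members the functions use: `rv_generic` defines `__call__` (which freezes the distribution) and `rv_frozen` defines `rvs`. This assumes a recent SciPy where both classes still live in `scipy.stats._distn_infrastructure`.

```python
from scipy.stats import norm
from scipy.stats._distn_infrastructure import rv_frozen, rv_generic

# `norm` is an rv_continuous instance, and rv_continuous subclasses rv_generic
print(isinstance(norm, rv_generic))

# Calling a generic distribution returns a frozen one (a subclass of rv_frozen)
frozen = norm(loc=1, scale=2)
print(isinstance(frozen, rv_frozen))

# Both members are defined directly on the base classes, so annotations
# using rv_generic / rv_frozen expose them to a type-checker
print("__call__" in vars(rv_generic))
print("rvs" in vars(rv_frozen))

print(frozen.rvs(4).shape)
```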
Now here's an MWE of the multivariate cases. If executed, it produces the correct result:
```python
from typing import Any

from scipy.stats import multivariate_normal
from scipy.stats._multivariate import multi_rv_generic, multi_rv_frozen


def sample_frozen_multivariate(n_samples: int, n_variates: int, dist: multi_rv_frozen):
    if dist.dim != n_variates:
        msg = 'distribution dimension %s inconsistent with n_variates=%s'
        raise ValueError(msg % (dist.dim, n_variates))
    sample = dist.rvs(n_samples)
    return sample


def sample_generic_multivariate(n_samples: int, n_variates: int, dist: multi_rv_generic,
                                *distparams: Any):
    sample = dist(*distparams).rvs(n_samples)
    if sample.shape[1] != n_variates:
        msg = 'sample dimension %s inconsistent with n_variates=%s'
        raise ValueError(msg % (sample.shape[1], n_variates))
    return sample


n_samples = 4
n_variates = 2
frozen_dist = multivariate_normal(mean=[0, 0], cov=[[1, 0], [0, 1]])
print(sample_frozen_multivariate(n_samples, n_variates, frozen_dist))

n_samples = 4
n_variates = 2
generic_dist = multivariate_normal
mean, cov = [-1, 1], [[1, 0], [0, 1]]
print(sample_generic_multivariate(n_samples, n_variates, generic_dist, mean, cov))
```
However, mypy reports the following issues:
```
example_2.py:8: error: "multi_rv_frozen[multi_rv_generic]" has no attribute "dim"  [attr-defined]
example_2.py:10: error: "multi_rv_frozen[multi_rv_generic]" has no attribute "dim"  [attr-defined]
example_2.py:11: error: "multi_rv_frozen[multi_rv_generic]" has no attribute "rvs"  [attr-defined]
example_2.py:17: error: "multi_rv_generic" not callable  [operator]
```
I understand the issues that mypy is reporting, just not how to properly address them. What approach(es) to type hints for generic and frozen multivariate distributions from scipy.stats would appease type-checkers? (As an aside, I am aware of the problem with relying on types from "private" modules in scipy.stats. Though I understand that is a separate issue, I welcome solutions that simultaneously address that problem.)
- Do you only want to use multivariate_normal? Not all frozen/generic distributions offer the same possibilities. – Daraan
Answer

multi_rv_generic and multi_rv_frozen are (minimal) base classes for random variables; they do not declare the members your functions rely on (dim and rvs on the frozen object, __call__ on the generic one).
SciPy lacks a richer common interface, presumably because not all distributions share the same characteristics, so these two base classes are not enough in this case.
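A minimal runtime sketch of the point above, assuming a recent SciPy where these private classes still live in `scipy.stats._multivariate`: the objects really are instances of the two base classes, but `rvs` and `__call__` are only defined on the concrete classes, and `dim` is an instance attribute set by the concrete frozen class. That matches the four mypy errors.

```python
from scipy.stats import multivariate_normal
from scipy.stats._multivariate import (
    multi_rv_frozen,
    multi_rv_generic,
    multivariate_normal_frozen,
)

frozen = multivariate_normal(mean=[0, 0], cov=[[1, 0], [0, 1]])

# Nominally, the annotations in the question match the runtime objects
print(isinstance(multivariate_normal, multi_rv_generic))
print(isinstance(frozen, multi_rv_frozen))

# ...but the base classes do not define the members the functions use
print(hasattr(multi_rv_frozen, "rvs"))             # not on the base class
print(hasattr(multivariate_normal_frozen, "rvs"))  # defined on the subclass
print("__call__" in vars(multi_rv_generic))        # no __call__ on the base
print("__call__" in vars(type(multivariate_normal)))

# `dim` is assigned in the concrete frozen class's __init__
print(frozen.dim)
```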
If you only want to use the multivariate normal distribution, it's fairly easy: annotate with its concrete generic and frozen classes.
```python
from typing import Any

from scipy.stats._multivariate import (
    multivariate_normal_frozen,
    multivariate_normal_gen,
)


def sample_frozen_multivariate(n_samples: int, n_variates: int, dist: multivariate_normal_frozen):
    if dist.dim != n_variates:
        msg = "distribution dimension %s inconsistent with n_variates=%s"
        raise ValueError(msg % (dist.dim, n_variates))
    sample = dist.rvs(n_samples)
    return sample


def sample_generic_multivariate(n_samples: int, n_variates: int, dist: multivariate_normal_gen, *distparams: Any):
    sample = dist(*distparams).rvs(n_samples)
    if sample.shape[1] != n_variates:
        msg = "sample dimension %s inconsistent with n_variates=%s"
        raise ValueError(msg % (sample.shape[1], n_variates))
    return sample
```
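A variation the answer does not show, sketched here under the same private-module assumption: keep the broad `multi_rv_frozen` annotation at the function boundary and narrow it with `isinstance` before touching `dim` or `rvs`. Type-checkers treat `dist` as `multivariate_normal_frozen` after the check, and unsupported distributions fail loudly at runtime instead of partway through.

```python
from scipy.stats import multivariate_normal
from scipy.stats._multivariate import multi_rv_frozen, multivariate_normal_frozen


def sample_frozen_multivariate(n_samples: int, n_variates: int, dist: multi_rv_frozen):
    # Narrow the base type to the concrete class; after this check a
    # type-checker knows `dist` has `dim` and `rvs`
    if not isinstance(dist, multivariate_normal_frozen):
        raise TypeError(f"unsupported distribution: {type(dist).__name__}")
    if dist.dim != n_variates:
        msg = "distribution dimension %s inconsistent with n_variates=%s"
        raise ValueError(msg % (dist.dim, n_variates))
    return dist.rvs(n_samples)


frozen_dist = multivariate_normal(mean=[0, 0], cov=[[1, 0], [0, 1]])
sample = sample_frozen_multivariate(4, 2, frozen_dist)
print(sample.shape)
```

Each additional supported distribution needs its own `isinstance` branch (or a tuple of classes), so this scales about as well as the explicit union in the next section.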
Extending to all multivariate distributions
As mentioned, not all distributions provide the same functionality; two examples are dirichlet_multinomial and normal_inverse_gamma.
Cover all existing distributions
To be on the safe side (as long as no new distribution is added), you can try to cover all types explicitly. Type-checking this code reports quite a few errors, which show the distributions your current code cannot handle yet. You then either have to exclude those distributions or update your code.
```python
from typing import Any

from scipy.stats._multivariate import (
    multivariate_normal_frozen,
    matrix_normal_frozen,
    dirichlet_frozen,
    wishart_frozen,
    invwishart_frozen,
    multinomial_frozen,
    special_ortho_group_frozen,
    ortho_group_frozen,
    random_correlation_frozen,
    unitary_group_frozen,
    multivariate_t_frozen,
    multivariate_hypergeom_frozen,
    random_table_frozen,
    uniform_direction_frozen,
    dirichlet_multinomial_frozen,
    vonmises_fisher_frozen,
    normal_inverse_gamma_frozen,
    multivariate_normal_gen,
    matrix_normal_gen,
    dirichlet_gen,
    wishart_gen,
    multinomial_gen,
    special_ortho_group_gen,
    ortho_group_gen,
    random_correlation_gen,
    unitary_group_gen,
    multivariate_t_gen,
    multivariate_hypergeom_gen,
    random_table_gen,
    uniform_direction_gen,
    dirichlet_multinomial_gen,
    vonmises_fisher_gen,
    normal_inverse_gamma_gen,
)

# Use TypeAlias or TypeAliasType for Python < 3.12
type FrozenDistType = (
    multivariate_normal_frozen
    | matrix_normal_frozen
    | dirichlet_frozen
    | wishart_frozen
    | invwishart_frozen
    | multinomial_frozen
    | special_ortho_group_frozen
    | ortho_group_frozen
    | random_correlation_frozen
    | unitary_group_frozen
    | multivariate_t_frozen
    | multivariate_hypergeom_frozen
    | random_table_frozen
    | uniform_direction_frozen
    | dirichlet_multinomial_frozen
    | vonmises_fisher_frozen
    | normal_inverse_gamma_frozen
)

# invwishart_gen subclasses wishart_gen, so wishart_gen also covers it here
type GenericDistType = (
    multivariate_normal_gen
    | matrix_normal_gen
    | dirichlet_gen
    | wishart_gen
    | multinomial_gen
    | special_ortho_group_gen
    | ortho_group_gen
    | random_correlation_gen
    | unitary_group_gen
    | multivariate_t_gen
    | multivariate_hypergeom_gen
    | random_table_gen
    | uniform_direction_gen
    | dirichlet_multinomial_gen
    | vonmises_fisher_gen
    | normal_inverse_gamma_gen
)


def sample_frozen_multivariate(n_samples: int, n_variates: int, dist: FrozenDistType):
    if dist.dim != n_variates:
        msg = "distribution dimension %s inconsistent with n_variates=%s"
        raise ValueError(msg % (dist.dim, n_variates))
    sample = dist.rvs(n_samples)
    return sample


def sample_generic_multivariate(n_samples: int, n_variates: int, dist: GenericDistType, *distparams: Any):
    sample = dist(*distparams).rvs(n_samples)
    if sample.shape[1] != n_variates:
        msg = "sample dimension %s inconsistent with n_variates=%s"
        raise ValueError(msg % (sample.shape[1], n_variates))
    return sample
```
Cover only the distributions your code currently supports (type-checkers are happy, but not all distributions are accepted)
Instead of covering all distributions, some of which your code cannot handle, you can use a Protocol
to restrict the accepted types to the cases you want to support. Only distributions that satisfy your current code are accepted.
```python
from typing import Any, Protocol

import numpy as np


class FrozenDistProtocol(Protocol):
    # Protocol attributes are declared in the class body, not via self.<name>
    dim: int

    def rvs(self, size: int, *args: Any, **kwargs: Any) -> np.ndarray:
        """
        Returns
        -------
        rvs : ndarray or scalar
        """
        ...


class GenericDistProtocol(Protocol):
    def rvs(self, *args: Any, **kwargs: Any) -> np.ndarray | Any:
        """
        Returns
        -------
        rvs : ndarray or scalar
        """
        ...

    # The function calls dist(*distparams), so the protocol needs __call__;
    # declaring __init__ would not make instances callable
    def __call__(self, *args: Any, **kwargs: Any) -> Any: ...


def sample_frozen_multivariate(n_samples: int, n_variates: int, dist: FrozenDistProtocol):
    if dist.dim != n_variates:
        msg = "distribution dimension %s inconsistent with n_variates=%s"
        raise ValueError(msg % (dist.dim, n_variates))
    sample = dist.rvs(n_samples)
    return sample


def sample_generic_multivariate(n_samples: int, n_variates: int, dist: GenericDistProtocol, *distparams: Any):
    sample = dist(*distparams).rvs(n_samples)
    if sample.shape[1] != n_variates:
        msg = "sample dimension %s inconsistent with n_variates=%s"
        raise ValueError(msg % (sample.shape[1], n_variates))
    return sample
```
You should test with some other distributions whether the signatures of the Protocol classes I provided are sufficient.
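One quick way to run that test, sketched here as an assumption rather than part of the original answer, is a `@runtime_checkable` copy of the frozen protocol. `isinstance` then reports which frozen distributions at least have a `dim` attribute and an `rvs` method. It only checks that the members exist, not their signatures or types, so it complements rather than replaces running mypy.

```python
from typing import Any, Protocol, runtime_checkable

from scipy.stats import dirichlet, multivariate_normal, multivariate_t


@runtime_checkable
class FrozenDistProtocol(Protocol):
    # Same members as the static protocol: a `dim` attribute and `rvs`
    dim: int

    def rvs(self, size: int, *args: Any, **kwargs: Any) -> Any: ...


candidates = {
    "multivariate_normal": multivariate_normal(mean=[0, 0]),
    "multivariate_t": multivariate_t(loc=[0, 0]),
    "dirichlet": dirichlet([1.0, 2.0, 3.0]),
}
for name, frozen in candidates.items():
    # isinstance() on a runtime-checkable protocol only checks that the
    # members exist; it does not compare signatures or types
    print(name, isinstance(frozen, FrozenDistProtocol))
```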