python - Appropriate type-hints for Generic/Frozen Multivariate Distributions in scipy.stats - Stack Overflow

I need to differentiate between 4 different sets of probability distributions from the scipy.stats module: generic univariate, frozen univariate, generic multivariate, & frozen multivariate. Throughout the application, I would like to add type hints for these 4 sets.

For the univariate cases, mypy has no problems with type hints like in this MWE:

from typing import Any

from scipy.stats import norm
from scipy.stats._distn_infrastructure import rv_generic, rv_frozen


def sample_frozen_univariate(n_samples: int, dist: rv_frozen):
    return dist.rvs(n_samples)


def sample_generic_univariate(n_samples: int, dist: rv_generic,
                              *distparams: Any):
    return dist(*distparams).rvs(n_samples)


n_samples = 4
frozen_dist = norm()
print(sample_frozen_univariate(n_samples, frozen_dist))

n_samples = 4
generic_dist = norm
loc, scale = 1, 2
print(sample_generic_univariate(n_samples, generic_dist, loc, scale))

Now here's an MWE of the multivariate cases. If executed it provides the correct result:

from typing import Any

from scipy.stats import multivariate_normal
from scipy.stats._multivariate import multi_rv_generic, multi_rv_frozen


def sample_frozen_multivariate(n_samples: int, n_variates: int, dist: multi_rv_frozen):
    if dist.dim != n_variates:
        msg = 'distribution dimension %s inconsistent with n_variates=%s'
        raise ValueError(msg % (dist.dim, n_variates))
    sample = dist.rvs(n_samples)
    return sample


def sample_generic_multivariate(n_samples: int, n_variates: int, dist: multi_rv_generic,
                                *distparams: Any):
    sample = dist(*distparams).rvs(n_samples)
    if sample.shape[1] != n_variates:
        msg = 'sample dimension %s inconsistent with n_variates=%s'
        raise ValueError(msg % (sample.shape[1], n_variates))
    return sample


n_samples = 4
n_variates = 2
frozen_dist = multivariate_normal(mean=[0, 0], cov=[[1, 0], [0, 1]])
print(sample_frozen_multivariate(n_samples, n_variates, frozen_dist))

n_samples = 4
n_variates = 2
generic_dist = multivariate_normal
mean, cov = [-1, 1], [[1, 0], [0, 1]]
print(sample_generic_multivariate(n_samples, n_variates, generic_dist, mean, cov))

However, mypy reports the following issues:

example_2.py:8: error: "multi_rv_frozen[multi_rv_generic]" has no attribute "dim"  [attr-defined]
example_2.py:10: error: "multi_rv_frozen[multi_rv_generic]" has no attribute "dim"  [attr-defined]
example_2.py:11: error: "multi_rv_frozen[multi_rv_generic]" has no attribute "rvs"  [attr-defined]
example_2.py:17: error: "multi_rv_generic" not callable  [operator]

I understand the issues that mypy is reporting, just not how to properly address them. What approach(es) to type hints for generic and frozen multivariate distributions from scipy.stats would appease type-checkers? (As an aside, I am aware of the problem with relying on types from "private" modules in scipy.stats. Though I understand that is a separate issue, I welcome solutions that simultaneously address that problem.)

Asked 2 days ago by brentertainer; edited yesterday by Daraan.

Comment (Daraan, yesterday): Do you only want to use multivariate_normal? Not all frozen/generic distributions offer the same functionality.

1 Answer

multi_rv_generic and multi_rv_frozen are minimal base classes for random variables; they do not define the interface you need (dim, rvs, or being callable). scipy probably lacks a richer common interface because not all distributions share the same characteristics, which is why you cannot simply annotate with these two classes in this case.
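
The effect can be sketched with stand-in classes. `MinimalFrozenBase` and `NormalLikeFrozen` below are hypothetical names for illustration, not scipy classes: the base declares almost nothing, and a concrete subclass adds `dim` and `rvs`. A checker that only sees the base annotation cannot know about members that a subclass adds.

```python
from __future__ import annotations

import random
from typing import Optional


class MinimalFrozenBase:
    """Hypothetical stand-in for a minimal base such as multi_rv_frozen."""

    def __init__(self, seed: Optional[int] = None) -> None:
        self._rng = random.Random(seed)


class NormalLikeFrozen(MinimalFrozenBase):
    """Hypothetical concrete distribution that adds `dim` and `rvs`."""

    def __init__(self, dim: int, seed: Optional[int] = None) -> None:
        super().__init__(seed)
        self.dim = dim

    def rvs(self, size: int) -> list[list[float]]:
        # One row per sample, one column per variate.
        return [
            [self._rng.gauss(0.0, 1.0) for _ in range(self.dim)]
            for _ in range(size)
        ]


def sample(dist: MinimalFrozenBase, n_samples: int):
    # A static checker rejects this attribute access: the base declares
    # no `rvs`. At runtime it works only if the object happens to define it.
    return dist.rvs(n_samples)  # type: ignore[attr-defined]


print(len(sample(NormalLikeFrozen(dim=2, seed=0), 4)))
```

Annotating with the concrete class instead of the base removes the need for the `type: ignore`, which is the idea behind the options that follow.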

If you only want to use multivariate_normal, it's fairly easy:

from typing import Any

from scipy.stats._multivariate import (
    multivariate_normal_frozen,
    multivariate_normal_gen,
)

def sample_frozen_multivariate(n_samples: int, n_variates: int, dist: multivariate_normal_frozen):
    if dist.dim != n_variates:
        msg = "distribution dimension %s inconsistent with n_variates=%s"
        raise ValueError(msg % (dist.dim, n_variates))
    sample = dist.rvs(n_samples)
    return sample


def sample_generic_multivariate(n_samples: int, n_variates: int, dist: multivariate_normal_gen, *distparams: Any):
    sample = dist(*distparams).rvs(n_samples)
    if sample.shape[1] != n_variates:
        msg = "sample dimension %s inconsistent with n_variates=%s"
        raise ValueError(msg % (sample.shape[1], n_variates))
    return sample

Extend to all multivariates

As noted, not all distributions provide the same functionality. Two examples: dirichlet_multinomial and normal_inverse_gamma.

Cover all existing distributions

To be on the safe side (as long as no new distribution is added), you can try to cover all types explicitly. The code below produces quite a few type errors, and these show which distributions your current code cannot handle yet. You then either have to exclude those distributions or update your code.

from typing import Any

from scipy.stats import multivariate_normal
from scipy.stats._multivariate import (
    multivariate_normal_frozen,
    matrix_normal_frozen,
    dirichlet_frozen,
    wishart_frozen,
    invwishart_frozen,
    multinomial_frozen,
    special_ortho_group_frozen,
    ortho_group_frozen,
    random_correlation_frozen,
    unitary_group_frozen,
    multivariate_t_frozen,
    multivariate_hypergeom_frozen,
    random_table_frozen,
    uniform_direction_frozen,
    dirichlet_multinomial_frozen,
    vonmises_fisher_frozen,
    normal_inverse_gamma_frozen,
    
    multivariate_normal_gen,
    matrix_normal_gen,
    dirichlet_gen,
    wishart_gen,
    multinomial_gen,
    special_ortho_group_gen,
    ortho_group_gen,
    random_correlation_gen,
    unitary_group_gen,
    multivariate_t_gen,
    multivariate_hypergeom_gen,
    random_table_gen,
    uniform_direction_gen,
    dirichlet_multinomial_gen,
    vonmises_fisher_gen,
    normal_inverse_gamma_gen,
)

# Use TypeAlias or TypeAliasType for Python < 3.12
type FrozenDistType = (
    multivariate_normal_frozen
    | matrix_normal_frozen
    | dirichlet_frozen
    | wishart_frozen
    | invwishart_frozen
    | multinomial_frozen
    | special_ortho_group_frozen
    | ortho_group_frozen
    | random_correlation_frozen
    | unitary_group_frozen
    | multivariate_t_frozen
    | multivariate_hypergeom_frozen
    | random_table_frozen
    | uniform_direction_frozen
    | dirichlet_multinomial_frozen
    | vonmises_fisher_frozen
    | normal_inverse_gamma_frozen
)

type GenericDistType = (
    multivariate_normal_gen
    | matrix_normal_gen
    | dirichlet_gen
    | wishart_gen
    | multinomial_gen
    | special_ortho_group_gen
    | ortho_group_gen
    | random_correlation_gen
    | unitary_group_gen
    | multivariate_t_gen
    | multivariate_hypergeom_gen
    | random_table_gen
    | uniform_direction_gen
    | dirichlet_multinomial_gen
    | vonmises_fisher_gen
    | normal_inverse_gamma_gen
)

def sample_frozen_multivariate(n_samples: int, n_variates: int, dist: FrozenDistType):
    if dist.dim != n_variates:
        msg = "distribution dimension %s inconsistent with n_variates=%s"
        raise ValueError(msg % (dist.dim, n_variates))
    sample = dist.rvs(n_samples)
    return sample


def sample_generic_multivariate(n_samples: int, n_variates: int, dist: GenericDistType, *distparams: Any):
    sample = dist(*distparams).rvs(n_samples)
    if sample.shape[1] != n_variates:
        msg = "sample dimension %s inconsistent with n_variates=%s"
        raise ValueError(msg % (sample.shape[1], n_variates))
    return sample

Cover only the distributions your code can currently handle (type checkers are happy; not all distributions are supported)

Instead of listing all distributions, some of which may not fit your use case, you can use a Protocol to restrict the accepted types to the ones you want to support. Only distributions that provide what your code actually uses are accepted.

from typing import Any, Protocol

import numpy as np


class FrozenDistProtocol(Protocol):
    # Data members are declared in the class body so checkers see them.
    dim: int

    def rvs(self, size: int, *args: Any, **kwargs: Any) -> np.ndarray:
        """
        Returns
        -------
        rvs : ndarray or scalar
        """
        ...


class GenericDistProtocol(Protocol):
    def rvs(self, *args: Any, **kwargs: Any) -> np.ndarray | Any:
        """
        Returns
        -------
        rvs : ndarray or scalar
        """
        ...

    # Calling the generic object freezes it, so the protocol needs __call__.
    def __call__(self, *args: Any, **kwargs: Any) -> FrozenDistProtocol: ...

def sample_frozen_multivariate(n_samples: int, n_variates: int, dist: FrozenDistProtocol):
    if dist.dim != n_variates:
        msg = "distribution dimension %s inconsistent with n_variates=%s"
        raise ValueError(msg % (dist.dim, n_variates))
    sample = dist.rvs(n_samples)
    return sample


def sample_generic_multivariate(n_samples: int, n_variates: int, dist: GenericDistProtocol, *distparams: Any):
    sample = dist(*distparams).rvs(n_samples)
    if sample.shape[1] != n_variates:
        msg = "sample dimension %s inconsistent with n_variates=%s"
        raise ValueError(msg % (sample.shape[1], n_variates))
    return sample

You should test with some other distributions to see whether the signatures of the Protocol classes I provided are sufficient.
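
For a quick runtime check of that kind, a runtime-checkable variant of the frozen protocol can be used with isinstance. This is a minimal sketch: `HasDimAndRvs`, `WithDim` and `WithoutDim` are hypothetical names, and the two plain classes stand in for real frozen distributions so that the example runs without scipy.

```python
from __future__ import annotations

from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class HasDimAndRvs(Protocol):
    """Hypothetical runtime-checkable variant of the frozen protocol."""

    dim: int

    def rvs(self, size: int, *args: Any, **kwargs: Any) -> Any: ...


class WithDim:
    """Stand-in for a frozen distribution that has `dim` and `rvs`."""

    def __init__(self, dim: int) -> None:
        self.dim = dim

    def rvs(self, size: int, random_state: Any = None) -> list[list[float]]:
        return [[0.0] * self.dim for _ in range(size)]


class WithoutDim:
    """Stand-in for a frozen distribution that lacks `dim`."""

    def rvs(self, size: int) -> list[float]:
        return [0.0] * size


def satisfies_frozen_protocol(obj: object) -> bool:
    # isinstance() on a runtime-checkable protocol only checks that the
    # members exist; it does not compare signatures or return types.
    return isinstance(obj, HasDimAndRvs)


print(satisfies_frozen_protocol(WithDim(2)), satisfies_frozen_protocol(WithoutDim()))
```

With scipy installed, the same function can be called on real frozen objects, such as the frozen_dist from the question, to list the distributions that structurally fit. Signature mismatches still have to be found by the static checker.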
