cbrkit.sim.aggregator

 1from collections.abc import Mapping, Sequence
 2from dataclasses import dataclass
 3from typing import override
 4
 5from ..helpers import unpack_float
 6from ..typing import (
 7    AggregatorFunc,
 8    Float,
 9    PoolingFunc,
10    SimMap,
11    SimSeq,
12)
13from .pooling import PoolingName, pooling_funcs
14
# Public names exported by this module.
__all__ = [
    "default_aggregator",
    "aggregator",
]
19
20
21@dataclass(slots=True, frozen=True)
22class aggregator[K](AggregatorFunc[K, Float]):
23    """
24    Aggregates local similarities to a global similarity using the specified pooling function.
25
26    Args:
27        pooling: The pooling function to use. It can be either a string representing the name of the pooling function or a custom pooling function (see `cbrkit.typing.PoolingFunc`).
28        pooling_weights: The weights to apply to the similarities during pooling. It can be a sequence or a mapping. If None, every local similarity is weighted equally.
29        default_pooling_weight: The default weight to use if a similarity key is not found in the pooling_weights mapping.
30
31    Examples:
32        >>> agg = aggregator("mean")
33        >>> agg([0.5, 0.75, 1.0])
34        0.75
35        >>> agg = aggregator("mean", {1: 1, 2: 1, 3: 0})
36        >>> agg({1: 1, 2: 1, 3: 1})
37        1.0
38        >>> agg = aggregator("mean", {1: 1, 2: 1, 3: 2})
39        >>> agg({1: 1, 2: 1, 3: 1})
40        1.0
41    """
42
43    pooling: PoolingName | PoolingFunc[float] = "mean"
44    pooling_weights: SimMap[K, float] | SimSeq[float] | None = None
45    default_pooling_weight: float = 1.0
46
47    @override
48    def __call__(self, similarities: SimMap[K, Float] | SimSeq[Float]) -> float:
49        pooling_func = (
50            pooling_funcs[self.pooling]
51            if isinstance(self.pooling, str)
52            else self.pooling
53        )
54        assert (self.pooling_weights is None) or (
55            type(similarities) is type(self.pooling_weights)  # noqa: E721
56        )
57
58        pooling_factor = 1.0
59        sims: Sequence[float]  # noqa: F821
60
61        if isinstance(similarities, Mapping) and isinstance(
62            self.pooling_weights, Mapping
63        ):
64            sims = [
65                unpack_float(sim)
66                * self.pooling_weights.get(key, self.default_pooling_weight)
67                for key, sim in similarities.items()
68            ]
69            pooling_factor = len(similarities) / sum(
70                self.pooling_weights.get(key, self.default_pooling_weight)
71                for key in similarities.keys()
72            )
73        elif isinstance(similarities, Sequence) and isinstance(
74            self.pooling_weights, Sequence
75        ):
76            sims = [
77                unpack_float(s) * w
78                for s, w in zip(similarities, self.pooling_weights, strict=True)
79            ]
80            pooling_factor = len(similarities) / sum(self.pooling_weights)
81        elif isinstance(similarities, Sequence) and self.pooling_weights is None:
82            sims = [unpack_float(s) for s in similarities]
83        elif isinstance(similarities, Mapping) and self.pooling_weights is None:
84            sims = [unpack_float(s) for s in similarities.values()]
85        else:
86            raise NotImplementedError()
87
88        return pooling_func(sims) * pooling_factor
89
90
# Module-level aggregator instance built with the default field values
# (pooling="mean", no pooling weights, default weight 1.0).
default_aggregator = aggregator()
default_aggregator = aggregator(pooling='mean', pooling_weights=None, default_pooling_weight=1.0)
@dataclass(slots=True, frozen=True)
class aggregator(cbrkit.typing.AggregatorFunc[K, Float], typing.Generic[K]):
22@dataclass(slots=True, frozen=True)
23class aggregator[K](AggregatorFunc[K, Float]):
24    """
25    Aggregates local similarities to a global similarity using the specified pooling function.
26
27    Args:
28        pooling: The pooling function to use. It can be either a string representing the name of the pooling function or a custom pooling function (see `cbrkit.typing.PoolingFunc`).
29        pooling_weights: The weights to apply to the similarities during pooling. It can be a sequence or a mapping. If None, every local similarity is weighted equally.
30        default_pooling_weight: The default weight to use if a similarity key is not found in the pooling_weights mapping.
31
32    Examples:
33        >>> agg = aggregator("mean")
34        >>> agg([0.5, 0.75, 1.0])
35        0.75
36        >>> agg = aggregator("mean", {1: 1, 2: 1, 3: 0})
37        >>> agg({1: 1, 2: 1, 3: 1})
38        1.0
39        >>> agg = aggregator("mean", {1: 1, 2: 1, 3: 2})
40        >>> agg({1: 1, 2: 1, 3: 1})
41        1.0
42    """
43
44    pooling: PoolingName | PoolingFunc[float] = "mean"
45    pooling_weights: SimMap[K, float] | SimSeq[float] | None = None
46    default_pooling_weight: float = 1.0
47
48    @override
49    def __call__(self, similarities: SimMap[K, Float] | SimSeq[Float]) -> float:
50        pooling_func = (
51            pooling_funcs[self.pooling]
52            if isinstance(self.pooling, str)
53            else self.pooling
54        )
55        assert (self.pooling_weights is None) or (
56            type(similarities) is type(self.pooling_weights)  # noqa: E721
57        )
58
59        pooling_factor = 1.0
60        sims: Sequence[float]  # noqa: F821
61
62        if isinstance(similarities, Mapping) and isinstance(
63            self.pooling_weights, Mapping
64        ):
65            sims = [
66                unpack_float(sim)
67                * self.pooling_weights.get(key, self.default_pooling_weight)
68                for key, sim in similarities.items()
69            ]
70            pooling_factor = len(similarities) / sum(
71                self.pooling_weights.get(key, self.default_pooling_weight)
72                for key in similarities.keys()
73            )
74        elif isinstance(similarities, Sequence) and isinstance(
75            self.pooling_weights, Sequence
76        ):
77            sims = [
78                unpack_float(s) * w
79                for s, w in zip(similarities, self.pooling_weights, strict=True)
80            ]
81            pooling_factor = len(similarities) / sum(self.pooling_weights)
82        elif isinstance(similarities, Sequence) and self.pooling_weights is None:
83            sims = [unpack_float(s) for s in similarities]
84        elif isinstance(similarities, Mapping) and self.pooling_weights is None:
85            sims = [unpack_float(s) for s in similarities.values()]
86        else:
87            raise NotImplementedError()
88
89        return pooling_func(sims) * pooling_factor

Aggregates local similarities to a global similarity using the specified pooling function.

Arguments:
  • pooling: The pooling function to use. It can be either a string representing the name of the pooling function or a custom pooling function (see cbrkit.typing.PoolingFunc).
  • pooling_weights: The weights to apply to the similarities during pooling. It can be a sequence or a mapping. If None, every local similarity is weighted equally.
  • default_pooling_weight: The default weight to use if a similarity key is not found in the pooling_weights mapping.
Examples:
>>> agg = aggregator("mean")
>>> agg([0.5, 0.75, 1.0])
0.75
>>> agg = aggregator("mean", {1: 1, 2: 1, 3: 0})
>>> agg({1: 1, 2: 1, 3: 1})
1.0
>>> agg = aggregator("mean", {1: 1, 2: 1, 3: 2})
>>> agg({1: 1, 2: 1, 3: 1})
1.0
aggregator(pooling: Union[PoolingName, cbrkit.typing.PoolingFunc[float]] = 'mean', pooling_weights: SimMap[K, float] | SimSeq[float] | None = None, default_pooling_weight: float = 1.0)
pooling: Union[PoolingName, cbrkit.typing.PoolingFunc[float]]
pooling_weights: SimMap[K, float] | SimSeq[float] | None
default_pooling_weight: float