cbrkit.sim.aggregator
1from collections.abc import Mapping, Sequence 2from dataclasses import dataclass 3from typing import override 4 5from ..helpers import unpack_float 6from ..typing import ( 7 AggregatorFunc, 8 Float, 9 PoolingFunc, 10 SimMap, 11 SimSeq, 12) 13from .pooling import PoolingName, pooling_funcs 14 15__all__ = [ 16 "default_aggregator", 17 "aggregator", 18] 19 20 21@dataclass(slots=True, frozen=True) 22class aggregator[K](AggregatorFunc[K, Float]): 23 """ 24 Aggregates local similarities to a global similarity using the specified pooling function. 25 26 Args: 27 pooling: The pooling function to use. It can be either a string representing the name of the pooling function or a custom pooling function (see `cbrkit.typing.PoolingFunc`). 28 pooling_weights: The weights to apply to the similarities during pooling. It can be a sequence or a mapping. If None, every local similarity is weighted equally. 29 default_pooling_weight: The default weight to use if a similarity key is not found in the pooling_weights mapping. 
30 31 Examples: 32 >>> agg = aggregator("mean") 33 >>> agg([0.5, 0.75, 1.0]) 34 0.75 35 >>> agg = aggregator("mean", {1: 1, 2: 1, 3: 0}) 36 >>> agg({1: 1, 2: 1, 3: 1}) 37 1.0 38 >>> agg = aggregator("mean", {1: 1, 2: 1, 3: 2}) 39 >>> agg({1: 1, 2: 1, 3: 1}) 40 1.0 41 """ 42 43 pooling: PoolingName | PoolingFunc[float] = "mean" 44 pooling_weights: SimMap[K, float] | SimSeq[float] | None = None 45 default_pooling_weight: float = 1.0 46 47 @override 48 def __call__(self, similarities: SimMap[K, Float] | SimSeq[Float]) -> float: 49 pooling_func = ( 50 pooling_funcs[self.pooling] 51 if isinstance(self.pooling, str) 52 else self.pooling 53 ) 54 assert (self.pooling_weights is None) or ( 55 type(similarities) is type(self.pooling_weights) # noqa: E721 56 ) 57 58 pooling_factor = 1.0 59 sims: Sequence[float] # noqa: F821 60 61 if isinstance(similarities, Mapping) and isinstance( 62 self.pooling_weights, Mapping 63 ): 64 sims = [ 65 unpack_float(sim) 66 * self.pooling_weights.get(key, self.default_pooling_weight) 67 for key, sim in similarities.items() 68 ] 69 pooling_factor = len(similarities) / sum( 70 self.pooling_weights.get(key, self.default_pooling_weight) 71 for key in similarities.keys() 72 ) 73 elif isinstance(similarities, Sequence) and isinstance( 74 self.pooling_weights, Sequence 75 ): 76 sims = [ 77 unpack_float(s) * w 78 for s, w in zip(similarities, self.pooling_weights, strict=True) 79 ] 80 pooling_factor = len(similarities) / sum(self.pooling_weights) 81 elif isinstance(similarities, Sequence) and self.pooling_weights is None: 82 sims = [unpack_float(s) for s in similarities] 83 elif isinstance(similarities, Mapping) and self.pooling_weights is None: 84 sims = [unpack_float(s) for s in similarities.values()] 85 else: 86 raise NotImplementedError() 87 88 return pooling_func(sims) * pooling_factor 89 90 91default_aggregator = aggregator()
# Module-wide default aggregator: the "mean" pooling function with every
# local similarity weighted equally (no pooling weights).
default_aggregator = aggregator(
    pooling="mean",
    pooling_weights=None,
    default_pooling_weight=1.0,
)
@dataclass(slots=True, frozen=True)
class
aggregator22@dataclass(slots=True, frozen=True) 23class aggregator[K](AggregatorFunc[K, Float]): 24 """ 25 Aggregates local similarities to a global similarity using the specified pooling function. 26 27 Args: 28 pooling: The pooling function to use. It can be either a string representing the name of the pooling function or a custom pooling function (see `cbrkit.typing.PoolingFunc`). 29 pooling_weights: The weights to apply to the similarities during pooling. It can be a sequence or a mapping. If None, every local similarity is weighted equally. 30 default_pooling_weight: The default weight to use if a similarity key is not found in the pooling_weights mapping. 31 32 Examples: 33 >>> agg = aggregator("mean") 34 >>> agg([0.5, 0.75, 1.0]) 35 0.75 36 >>> agg = aggregator("mean", {1: 1, 2: 1, 3: 0}) 37 >>> agg({1: 1, 2: 1, 3: 1}) 38 1.0 39 >>> agg = aggregator("mean", {1: 1, 2: 1, 3: 2}) 40 >>> agg({1: 1, 2: 1, 3: 1}) 41 1.0 42 """ 43 44 pooling: PoolingName | PoolingFunc[float] = "mean" 45 pooling_weights: SimMap[K, float] | SimSeq[float] | None = None 46 default_pooling_weight: float = 1.0 47 48 @override 49 def __call__(self, similarities: SimMap[K, Float] | SimSeq[Float]) -> float: 50 pooling_func = ( 51 pooling_funcs[self.pooling] 52 if isinstance(self.pooling, str) 53 else self.pooling 54 ) 55 assert (self.pooling_weights is None) or ( 56 type(similarities) is type(self.pooling_weights) # noqa: E721 57 ) 58 59 pooling_factor = 1.0 60 sims: Sequence[float] # noqa: F821 61 62 if isinstance(similarities, Mapping) and isinstance( 63 self.pooling_weights, Mapping 64 ): 65 sims = [ 66 unpack_float(sim) 67 * self.pooling_weights.get(key, self.default_pooling_weight) 68 for key, sim in similarities.items() 69 ] 70 pooling_factor = len(similarities) / sum( 71 self.pooling_weights.get(key, self.default_pooling_weight) 72 for key in similarities.keys() 73 ) 74 elif isinstance(similarities, Sequence) and isinstance( 75 self.pooling_weights, Sequence 76 ): 77 sims = [ 78 
unpack_float(s) * w 79 for s, w in zip(similarities, self.pooling_weights, strict=True) 80 ] 81 pooling_factor = len(similarities) / sum(self.pooling_weights) 82 elif isinstance(similarities, Sequence) and self.pooling_weights is None: 83 sims = [unpack_float(s) for s in similarities] 84 elif isinstance(similarities, Mapping) and self.pooling_weights is None: 85 sims = [unpack_float(s) for s in similarities.values()] 86 else: 87 raise NotImplementedError() 88 89 return pooling_func(sims) * pooling_factor
Aggregates local similarities to a global similarity using the specified pooling function.
Arguments:
- pooling: The pooling function to use. It can be either a string representing the name of the pooling function or a custom pooling function (see `cbrkit.typing.PoolingFunc`).
- pooling_weights: The weights to apply to the similarities during pooling. It can be a sequence (matched to the similarities by position) or a mapping (matched by key). If None, every local similarity is weighted equally.
- default_pooling_weight: The default weight to use if a similarity key is not found in the pooling_weights mapping.
Examples:
>>> agg = aggregator("mean")
>>> agg([0.5, 0.75, 1.0])
0.75
>>> agg = aggregator("mean", {1: 1, 2: 1, 3: 0})
>>> agg({1: 1, 2: 1, 3: 1})
1.0
>>> agg = aggregator("mean", {1: 1, 2: 1, 3: 2})
>>> agg({1: 1, 2: 1, 3: 1})
1.0
aggregator( pooling: Union[PoolingName, cbrkit.typing.PoolingFunc[float]] = 'mean', pooling_weights: SimMap[K, float] | SimSeq[float] | None = None, default_pooling_weight: float = 1.0)
pooling: Union[PoolingName, cbrkit.typing.PoolingFunc[float]]