cbrkit.sim.aggregator
1from collections.abc import Mapping, Sequence 2from dataclasses import dataclass 3from typing import cast, override 4 5from ..helpers import unpack_float 6from ..typing import ( 7 AggregatorFunc, 8 Float, 9 PoolingFunc, 10 SimMap, 11 SimSeq, 12) 13from .pooling import PoolingName, pooling_funcs 14 15__all__ = [ 16 "default_aggregator", 17 "aggregator", 18] 19 20 21@dataclass(slots=True, frozen=True) 22class aggregator[K](AggregatorFunc[K, Float]): 23 """ 24 Aggregates local similarities to a global similarity using the specified pooling function. 25 26 Args: 27 pooling: The pooling function to use. It can be either a string representing the name of the pooling function or a custom pooling function (see `cbrkit.typing.PoolingFunc`). 28 pooling_weights: The weights to apply to the similarities during pooling. It can be a sequence or a mapping. If None, every local similarity is weighted equally. 29 default_pooling_weight: The default weight to use if a similarity key is not found in the pooling_weights mapping. 
30 31 Examples: 32 >>> agg = aggregator("mean") 33 >>> agg([0.5, 0.75, 1.0]) 34 0.75 35 >>> agg = aggregator("mean", {1: 1, 2: 1, 3: 0}) 36 >>> agg({1: 1, 2: 1, 3: 1}) 37 1.0 38 >>> agg = aggregator("mean", {1: 1, 2: 1, 3: 2}) 39 >>> agg({1: 1, 2: 1, 3: 1}) 40 1.0 41 """ 42 43 pooling: PoolingName | PoolingFunc[float] = "mean" 44 pooling_weights: SimMap[K, float] | SimSeq[float] | None = None 45 default_pooling_weight: float = 1.0 46 47 @override 48 def __call__(self, similarities: SimMap[K, Float] | SimSeq[Float]) -> float: 49 pooling_func = ( 50 pooling_funcs[cast(PoolingName, self.pooling)] 51 if isinstance(self.pooling, str) 52 else self.pooling 53 ) 54 assert (self.pooling_weights is None) or ( 55 type(similarities) is type(self.pooling_weights) # noqa: E721 56 ) 57 58 pooling_factor = 1.0 59 sims: Sequence[float] # noqa: F821 60 61 if isinstance(similarities, Mapping) and isinstance( 62 self.pooling_weights, Mapping 63 ): 64 sim_map = cast(SimMap[K, Float], similarities) 65 weight_map = cast(SimMap[K, float], self.pooling_weights) 66 sims = [ 67 unpack_float(sim) * weight_map.get(key, self.default_pooling_weight) 68 for key, sim in sim_map.items() 69 ] 70 pooling_factor = len(sim_map) / sum( 71 weight_map.get(key, self.default_pooling_weight) 72 for key in sim_map.keys() 73 ) 74 elif isinstance(similarities, Sequence) and isinstance( 75 self.pooling_weights, Sequence 76 ): 77 sims = [ 78 unpack_float(s) * w 79 for s, w in zip(similarities, self.pooling_weights, strict=True) 80 ] 81 pooling_factor = len(similarities) / sum(self.pooling_weights) 82 elif isinstance(similarities, Sequence) and self.pooling_weights is None: 83 sim_seq = cast(SimSeq[Float], similarities) 84 sims = [unpack_float(s) for s in sim_seq] 85 elif isinstance(similarities, Mapping) and self.pooling_weights is None: 86 sim_map = cast(SimMap[K, Float], similarities) 87 sims = [unpack_float(s) for s in sim_map.values()] 88 else: 89 raise NotImplementedError() 90 91 return pooling_func(sims) * 
pooling_factor 92 93 94default_aggregator = aggregator()
default_aggregator = aggregator(pooling='mean', pooling_weights=None, default_pooling_weight=1.0)
@dataclass(slots=True, frozen=True)
class
aggregator22@dataclass(slots=True, frozen=True) 23class aggregator[K](AggregatorFunc[K, Float]): 24 """ 25 Aggregates local similarities to a global similarity using the specified pooling function. 26 27 Args: 28 pooling: The pooling function to use. It can be either a string representing the name of the pooling function or a custom pooling function (see `cbrkit.typing.PoolingFunc`). 29 pooling_weights: The weights to apply to the similarities during pooling. It can be a sequence or a mapping. If None, every local similarity is weighted equally. 30 default_pooling_weight: The default weight to use if a similarity key is not found in the pooling_weights mapping. 31 32 Examples: 33 >>> agg = aggregator("mean") 34 >>> agg([0.5, 0.75, 1.0]) 35 0.75 36 >>> agg = aggregator("mean", {1: 1, 2: 1, 3: 0}) 37 >>> agg({1: 1, 2: 1, 3: 1}) 38 1.0 39 >>> agg = aggregator("mean", {1: 1, 2: 1, 3: 2}) 40 >>> agg({1: 1, 2: 1, 3: 1}) 41 1.0 42 """ 43 44 pooling: PoolingName | PoolingFunc[float] = "mean" 45 pooling_weights: SimMap[K, float] | SimSeq[float] | None = None 46 default_pooling_weight: float = 1.0 47 48 @override 49 def __call__(self, similarities: SimMap[K, Float] | SimSeq[Float]) -> float: 50 pooling_func = ( 51 pooling_funcs[cast(PoolingName, self.pooling)] 52 if isinstance(self.pooling, str) 53 else self.pooling 54 ) 55 assert (self.pooling_weights is None) or ( 56 type(similarities) is type(self.pooling_weights) # noqa: E721 57 ) 58 59 pooling_factor = 1.0 60 sims: Sequence[float] # noqa: F821 61 62 if isinstance(similarities, Mapping) and isinstance( 63 self.pooling_weights, Mapping 64 ): 65 sim_map = cast(SimMap[K, Float], similarities) 66 weight_map = cast(SimMap[K, float], self.pooling_weights) 67 sims = [ 68 unpack_float(sim) * weight_map.get(key, self.default_pooling_weight) 69 for key, sim in sim_map.items() 70 ] 71 pooling_factor = len(sim_map) / sum( 72 weight_map.get(key, self.default_pooling_weight) 73 for key in sim_map.keys() 74 ) 75 elif 
isinstance(similarities, Sequence) and isinstance( 76 self.pooling_weights, Sequence 77 ): 78 sims = [ 79 unpack_float(s) * w 80 for s, w in zip(similarities, self.pooling_weights, strict=True) 81 ] 82 pooling_factor = len(similarities) / sum(self.pooling_weights) 83 elif isinstance(similarities, Sequence) and self.pooling_weights is None: 84 sim_seq = cast(SimSeq[Float], similarities) 85 sims = [unpack_float(s) for s in sim_seq] 86 elif isinstance(similarities, Mapping) and self.pooling_weights is None: 87 sim_map = cast(SimMap[K, Float], similarities) 88 sims = [unpack_float(s) for s in sim_map.values()] 89 else: 90 raise NotImplementedError() 91 92 return pooling_func(sims) * pooling_factor
Aggregates local similarities to a global similarity using the specified pooling function.
Arguments:
- pooling: The pooling function to use. It can be either a string representing the name of the pooling function or a custom pooling function (see `cbrkit.typing.PoolingFunc`).
- pooling_weights: The weights to apply to the similarities during pooling. It can be a sequence or a mapping. If None, every local similarity is weighted equally.
- default_pooling_weight: The default weight to use if a similarity key is not found in the pooling_weights mapping.
Examples:
>>> agg = aggregator("mean")
>>> agg([0.5, 0.75, 1.0])
0.75
>>> agg = aggregator("mean", {1: 1, 2: 1, 3: 0})
>>> agg({1: 1, 2: 1, 3: 1})
1.0
>>> agg = aggregator("mean", {1: 1, 2: 1, 3: 2})
>>> agg({1: 1, 2: 1, 3: 1})
1.0
aggregator(pooling: Union[PoolingName, cbrkit.typing.PoolingFunc[float]] = 'mean', pooling_weights: SimMap[K, float] | SimSeq[float] | None = None, default_pooling_weight: float = 1.0)
pooling: Union[PoolingName, cbrkit.typing.PoolingFunc[float]]