cbrkit.sim
CBRkit contains a selection of similarity measures for different data types.
Besides measures for standard data types like numbers (`cbrkit.sim.numbers`), strings (`cbrkit.sim.strings`), lists/collections (`cbrkit.sim.collections`), and generic data (`cbrkit.sim.generic`),
there is also a measure for attribute-value data.
Additionally, the module contains an aggregator to combine multiple local measures into a global score.
"""
CBRkit contains a selection of similarity measures for different data types.
Besides measures for standard data types like
numbers (`cbrkit.sim.numbers`),
strings (`cbrkit.sim.strings`),
lists/collections (`cbrkit.sim.collections`),
and generic data (`cbrkit.sim.generic`),
there is also a measure for attribute-value data.
Additionally, the module contains an aggregator to combine multiple local measures into a global score.

Further submodules are `cbrkit.sim.graphs`, `cbrkit.sim.embed`,
`cbrkit.sim.taxonomy`, and `cbrkit.sim.pooling`.
The wrappers `transpose`, `transpose_value`, `cache`, `combine`, `table`,
`dynamic_table`, `type_table`, and `attribute_table` are re-exported here
from `cbrkit.sim.wrappers`.
"""

from . import collections, embed, generic, graphs, numbers, pooling, strings, taxonomy
from .aggregator import aggregator
from .pooling import PoolingName
from .attribute_value import AttributeValueSim, attribute_value
from .wrappers import (
    attribute_table,
    cache,
    combine,
    dynamic_table,
    table,
    transpose,
    transpose_value,
    type_table,
)

__all__ = [
    "transpose",
    "transpose_value",
    "cache",
    "combine",
    "table",
    "dynamic_table",
    "type_table",
    "attribute_table",
    "collections",
    "generic",
    "numbers",
    "strings",
    "attribute_value",
    "graphs",
    "embed",
    "taxonomy",
    "pooling",
    "aggregator",
    "PoolingName",
    "AttributeValueSim",
]
23@dataclass(slots=True) 24class transpose[V1, V2, S: Float](BatchSimFunc[V1, S]): 25 """Transforms a similarity function from one type to another. 26 27 Args: 28 similarity_func: The similarity function to be used on the converted values. 29 conversion_func: A function that converts the input values from one type to another. 30 31 Examples: 32 >>> from cbrkit.sim.generic import equality 33 >>> sim = transpose( 34 ... similarity_func=equality(), 35 ... conversion_func=lambda x: x.lower(), 36 ... ) 37 >>> sim([("A", "a"), ("b", "B")]) 38 [1.0, 1.0] 39 """ 40 41 similarity_func: BatchSimFunc[V2, S] 42 conversion_func: ConversionFunc[V1, V2] 43 44 def __init__( 45 self, 46 similarity_func: AnySimFunc[V2, S], 47 conversion_func: ConversionFunc[V1, V2], 48 ): 49 self.similarity_func = batchify_sim(similarity_func) 50 self.conversion_func = conversion_func 51 52 @override 53 def __call__(self, batches: Sequence[tuple[V1, V1]]) -> Sequence[S]: 54 return self.similarity_func( 55 [(self.conversion_func(x), self.conversion_func(y)) for x, y in batches] 56 )
Transforms a similarity function from one type to another.
Arguments:
- similarity_func: The similarity function to be used on the converted values.
- conversion_func: A function that converts the input values from one type to another.
Examples:
>>> from cbrkit.sim.generic import equality
>>> sim = transpose(
...     similarity_func=equality(),
...     conversion_func=lambda x: x.lower(),
... )
>>> sim([("A", "a"), ("b", "B")])
[1.0, 1.0]
126@dataclass(slots=True) 127class cache[V, U, S: Float](BatchSimFunc[V, S]): 128 similarity_func: BatchSimFunc[V, S] 129 conversion_func: ConversionFunc[V, U] | None 130 store: MutableMapping[tuple[U, U], S] = field(repr=False) 131 132 def __init__( 133 self, 134 similarity_func: AnySimFunc[V, S], 135 conversion_func: ConversionFunc[V, U] | None = None, 136 ): 137 self.similarity_func = batchify_sim(similarity_func) 138 self.conversion_func = conversion_func 139 self.store = {} 140 141 @override 142 def __call__(self, batches: Sequence[tuple[V, V]]) -> SimSeq[S]: 143 transformed_batches = ( 144 [(self.conversion_func(x), self.conversion_func(y)) for x, y in batches] 145 if self.conversion_func is not None 146 else cast(list[tuple[U, U]], batches) 147 ) 148 uncached_indexes = [ 149 idx 150 for idx, pair in enumerate(transformed_batches) 151 if pair not in self.store 152 ] 153 154 uncached_sims = self.similarity_func([batches[idx] for idx in uncached_indexes]) 155 self.store.update( 156 { 157 transformed_batches[idx]: sim 158 for idx, sim in zip(uncached_indexes, uncached_sims, strict=True) 159 } 160 ) 161 162 return [self.store[pair] for pair in transformed_batches]
65@dataclass(slots=True) 66class combine[V, S: Float](BatchSimFunc[V, float]): 67 """Combines multiple similarity functions into one. 68 69 Args: 70 sim_funcs: A list of similarity functions to be combined. 71 aggregator: A function to aggregate the results from the similarity functions. 72 73 Returns: 74 A similarity function that combines the results from multiple similarity functions. 75 """ 76 77 sim_funcs: InitVar[Sequence[AnySimFunc[V, S]] | Mapping[str, AnySimFunc[V, S]]] 78 aggregator: AggregatorFunc[str, S] = default_aggregator 79 batch_sim_funcs: Sequence[BatchSimFunc[V, S]] | Mapping[str, BatchSimFunc[V, S]] = ( 80 field(init=False, repr=False) 81 ) 82 83 def __post_init__( 84 self, sim_funcs: Sequence[AnySimFunc[V, S]] | Mapping[str, AnySimFunc[V, S]] 85 ): 86 if isinstance(sim_funcs, Sequence): 87 self.batch_sim_funcs = [batchify_sim(func) for func in sim_funcs] 88 elif isinstance(sim_funcs, Mapping): 89 self.batch_sim_funcs = { 90 key: batchify_sim(func) for key, func in sim_funcs.items() 91 } 92 else: 93 raise ValueError(f"Invalid sim_funcs type: {type(sim_funcs)}") 94 95 @override 96 def __call__(self, batches: Sequence[tuple[V, V]]) -> Sequence[float]: 97 if isinstance(self.batch_sim_funcs, Sequence): 98 func_results = [func(batches) for func in self.batch_sim_funcs] 99 100 return [ 101 self.aggregator( 102 [batch_results[batch_idx] for batch_results in func_results] 103 ) 104 for batch_idx in range(len(batches)) 105 ] 106 107 elif isinstance(self.batch_sim_funcs, Mapping): 108 func_results = { 109 func_key: func(batches) 110 for func_key, func in self.batch_sim_funcs.items() 111 } 112 113 return [ 114 self.aggregator( 115 { 116 func_key: batch_results[batch_idx] 117 for func_key, batch_results in func_results.items() 118 } 119 ) 120 for batch_idx in range(len(batches)) 121 ] 122 123 raise ValueError(f"Invalid batch_sim_funcs type: {type(self.batch_sim_funcs)}")
Combines multiple similarity functions into one.
Arguments:
- sim_funcs: A sequence (or a mapping of names) of similarity functions to be combined.
- aggregator: A function to aggregate the results from the similarity functions.
Returns:
A similarity function that combines the results from multiple similarity functions.
165@dataclass(slots=True) 166class dynamic_table[K, U, V, S: Float](BatchSimFunc[U | V, S], HasMetadata): 167 """Allows to import a similarity values from a table. 168 169 Args: 170 entries: Sequence[tuple[a, b, sim(a, b)] 171 symmetric: If True, the table is assumed to be symmetric, i.e. sim(a, b) = sim(b, a) 172 default: Default similarity value for pairs not in the table 173 key_getter: A function that extracts the the key for lookup from the input values 174 175 Examples: 176 >>> from cbrkit.helpers import identity 177 >>> from cbrkit.sim.generic import static 178 >>> sim = dynamic_table( 179 ... { 180 ... ("a", "b"): static(0.5), 181 ... ("b", "c"): static(0.7) 182 ... }, 183 ... symmetric=True, 184 ... default=static(0.0), 185 ... key_getter=identity, 186 ... ) 187 >>> sim([("b", "a"), ("a", "c")]) 188 [0.5, 0.0] 189 """ 190 191 symmetric: bool 192 default: BatchSimFunc[U, S] | None 193 key_getter: Callable[[Any], K] 194 table: dict[tuple[K, K], BatchSimFunc[V, S]] 195 196 @property 197 @override 198 def metadata(self) -> JsonDict: 199 return { 200 "symmetric": self.symmetric, 201 "default": get_metadata(self.default), 202 "key_getter": get_metadata(self.key_getter), 203 "table": [ 204 { 205 "x": str(k[0]), 206 "y": str(k[1]), 207 "value": get_metadata(v), 208 } 209 for k, v in self.table.items() 210 ], 211 } 212 213 def __init__( 214 self, 215 entries: Mapping[tuple[K, K], AnySimFunc[..., S]] 216 | Mapping[K, AnySimFunc[..., S]], 217 key_getter: Callable[[Any], K], 218 default: AnySimFunc[U, S] | S | None = None, 219 symmetric: bool = True, 220 ): 221 self.symmetric = symmetric 222 self.key_getter = key_getter 223 self.table = {} 224 225 if isinstance(default, Callable): 226 self.default = batchify_sim(default) 227 elif default is None: 228 self.default = None 229 else: 230 self.default = batchify_sim(static(default)) 231 232 for key, val in entries.items(): 233 func = batchify_sim(val) 234 235 if isinstance(key, tuple): 236 x, y = cast(tuple[K, K], key) 237 
else: 238 x = y = cast(K, key) 239 240 self.table[(x, y)] = func 241 242 if self.symmetric and x != y: 243 self.table[(y, x)] = func 244 245 @override 246 def __call__(self, batches: Sequence[tuple[U | V, U | V]]) -> SimSeq[S]: 247 # then we group the batches by key to avoid redundant computations 248 idx_map: defaultdict[tuple[K, K] | None, list[int]] = defaultdict(list) 249 250 for idx, (x, y) in enumerate(batches): 251 key = (self.key_getter(x), self.key_getter(y)) 252 253 if key in self.table: 254 idx_map[key].append(idx) 255 else: 256 idx_map[None].append(idx) 257 258 # now we compute the similarities 259 results: dict[int, S] = {} 260 261 for key, idxs in idx_map.items(): 262 sim_func = cast( 263 BatchSimFunc[U | V, S] | None, 264 self.table.get(key) if key is not None else self.default, 265 ) 266 267 if sim_func is None: 268 missing_entries = [batches[idx] for idx in idxs] 269 missing_keys = { 270 (self.key_getter(x), self.key_getter(y)) for x, y in missing_entries 271 } 272 273 raise ValueError(f"Pairs {missing_keys} not in the table") 274 275 sims = sim_func([batches[idx] for idx in idxs]) 276 277 for idx, sim in zip(idxs, sims, strict=True): 278 results[idx] = sim 279 280 return [results[idx] for idx in range(len(batches))]
Looks up similarity functions for pairs of values in a table.
Arguments:
- entries: Mapping from keys to similarity functions; a key is either a pair (a, b) or a single value a, which is stored as (a, a)
- symmetric: If True, the table is assumed to be symmetric, i.e. sim(a, b) = sim(b, a)
- default: Default similarity function or value for pairs not in the table; if None, such pairs raise an error
- key_getter: A function that extracts the key for lookup from the input values
Examples:
>>> from cbrkit.helpers import identity
>>> from cbrkit.sim.generic import static
>>> sim = dynamic_table(
...     {
...         ("a", "b"): static(0.5),
...         ("b", "c"): static(0.7)
...     },
...     symmetric=True,
...     default=static(0.0),
...     key_getter=identity,
... )
>>> sim([("b", "a"), ("a", "c")])
[0.5, 0.0]
213 def __init__( 214 self, 215 entries: Mapping[tuple[K, K], AnySimFunc[..., S]] 216 | Mapping[K, AnySimFunc[..., S]], 217 key_getter: Callable[[Any], K], 218 default: AnySimFunc[U, S] | S | None = None, 219 symmetric: bool = True, 220 ): 221 self.symmetric = symmetric 222 self.key_getter = key_getter 223 self.table = {} 224 225 if isinstance(default, Callable): 226 self.default = batchify_sim(default) 227 elif default is None: 228 self.default = None 229 else: 230 self.default = batchify_sim(static(default)) 231 232 for key, val in entries.items(): 233 func = batchify_sim(val) 234 235 if isinstance(key, tuple): 236 x, y = cast(tuple[K, K], key) 237 else: 238 x = y = cast(K, key) 239 240 self.table[(x, y)] = func 241 242 if self.symmetric and x != y: 243 self.table[(y, x)] = func
Inherited Members
307def attribute_table[K, U, S: Float]( 308 entries: Mapping[K, AnySimFunc[..., S]], 309 attribute: str, 310 default: AnySimFunc[U, S] | S | None = None, 311 value_getter: Callable[[Any, str], K] = getitem_or_getattr, 312) -> BatchSimFunc[Any, S]: 313 key_getter = attribute_table_key_getter(value_getter, attribute) 314 315 return dynamic_table( 316 entries=entries, 317 key_getter=key_getter, 318 default=default, 319 symmetric=False, 320 )
27@dataclass(slots=True, frozen=True) 28class attribute_value[V, S: Float](BatchSimFunc[V, AttributeValueSim[S]]): 29 """Similarity function that computes the attribute value similarity between two cases. 30 31 Args: 32 attributes: A mapping of attribute names to the similarity functions to be used for those attributes. 33 aggregator: A function that aggregates the local similarity scores for each attribute into a single global similarity. 34 value_getter: A function that retrieves the value of an attribute from a case. 35 default: The default similarity score to use when an error occurs during the computation of a similarity score. 36 For example, if a case does not have an attribute that is required for the similarity computation. 37 38 Examples: 39 >>> equality = lambda x, y: 1.0 if x == y else 0.0 40 >>> sim = attribute_value({ 41 ... "name": equality, 42 ... "age": equality, 43 ... }) 44 >>> scores = sim([ 45 ... ({"name": "John", "age": 25}, {"name": "John", "age": 30}), 46 ... ({"name": "Jane", "age": 30}, {"name": "John", "age": 30}), 47 ... 
]) 48 >>> scores[0] 49 AttributeValueSim(value=0.5, attributes={'name': 1.0, 'age': 0.0}) 50 >>> scores[1] 51 AttributeValueSim(value=0.5, attributes={'name': 0.0, 'age': 1.0}) 52 """ 53 54 attributes: Mapping[str, AnySimFunc[Any, S]] 55 aggregator: AggregatorFunc[str, S] = default_aggregator 56 value_getter: Callable[[Any, str], Any] = getitem_or_getattr 57 default: S | None = None 58 59 @override 60 def __call__(self, batches: Sequence[tuple[V, V]]) -> SimSeq[AttributeValueSim[S]]: 61 if len(batches) == 0: 62 return [] 63 64 local_sims: list[dict[str, S]] = [dict() for _ in range(len(batches))] 65 66 for attr_name in self.attributes: 67 logger.debug(f"Processing attribute {attr_name}") 68 69 try: 70 attribute_values = [ 71 (self.value_getter(x, attr_name), self.value_getter(y, attr_name)) 72 for x, y in batches 73 ] 74 sim_func = batchify_sim(self.attributes[attr_name]) 75 sim_func_result = sim_func(attribute_values) 76 77 for idx, sim in enumerate(sim_func_result): 78 local_sims[idx][attr_name] = sim 79 80 except Exception as e: 81 if self.default is not None: 82 for idx in range(len(batches)): 83 local_sims[idx][attr_name] = self.default 84 else: 85 raise e 86 87 return [AttributeValueSim(self.aggregator(sims), sims) for sims in local_sims]
Similarity function that computes the attribute value similarity between two cases.
Arguments:
- attributes: A mapping of attribute names to the similarity functions to be used for those attributes.
- aggregator: A function that aggregates the local similarity scores for each attribute into a single global similarity.
- value_getter: A function that retrieves the value of an attribute from a case.
- default: The default similarity score to use when an error occurs during the computation of a similarity score. For example, if a case does not have an attribute that is required for the similarity computation.
Examples:
>>> equality = lambda x, y: 1.0 if x == y else 0.0
>>> sim = attribute_value({
...     "name": equality,
...     "age": equality,
... })
>>> scores = sim([
...     ({"name": "John", "age": 25}, {"name": "John", "age": 30}),
...     ({"name": "Jane", "age": 30}, {"name": "John", "age": 30}),
... ])
>>> scores[0]
AttributeValueSim(value=0.5, attributes={'name': 1.0, 'age': 0.0})
>>> scores[1]
AttributeValueSim(value=0.5, attributes={'name': 0.0, 'age': 1.0})
22@dataclass(slots=True, frozen=True) 23class aggregator[K](AggregatorFunc[K, Float]): 24 """ 25 Aggregates local similarities to a global similarity using the specified pooling function. 26 27 Args: 28 pooling: The pooling function to use. It can be either a string representing the name of the pooling function or a custom pooling function (see `cbrkit.typing.PoolingFunc`). 29 pooling_weights: The weights to apply to the similarities during pooling. It can be a sequence or a mapping. If None, every local similarity is weighted equally. 30 default_pooling_weight: The default weight to use if a similarity key is not found in the pooling_weights mapping. 31 32 Examples: 33 >>> agg = aggregator("mean") 34 >>> agg([0.5, 0.75, 1.0]) 35 0.75 36 >>> agg = aggregator("mean", {1: 1, 2: 1, 3: 0}) 37 >>> agg({1: 1, 2: 1, 3: 1}) 38 1.0 39 >>> agg = aggregator("mean", {1: 1, 2: 1, 3: 2}) 40 >>> agg({1: 1, 2: 1, 3: 1}) 41 1.0 42 """ 43 44 pooling: PoolingName | PoolingFunc[float] = "mean" 45 pooling_weights: SimMap[K, float] | SimSeq[float] | None = None 46 default_pooling_weight: float = 1.0 47 48 @override 49 def __call__(self, similarities: SimMap[K, Float] | SimSeq[Float]) -> float: 50 pooling_func = ( 51 pooling_funcs[self.pooling] 52 if isinstance(self.pooling, str) 53 else self.pooling 54 ) 55 assert (self.pooling_weights is None) or ( 56 type(similarities) is type(self.pooling_weights) # noqa: E721 57 ) 58 59 pooling_factor = 1.0 60 sims: Sequence[float] # noqa: F821 61 62 if isinstance(similarities, Mapping) and isinstance( 63 self.pooling_weights, Mapping 64 ): 65 sims = [ 66 unpack_float(sim) 67 * self.pooling_weights.get(key, self.default_pooling_weight) 68 for key, sim in similarities.items() 69 ] 70 pooling_factor = len(similarities) / sum( 71 self.pooling_weights.get(key, self.default_pooling_weight) 72 for key in similarities.keys() 73 ) 74 elif isinstance(similarities, Sequence) and isinstance( 75 self.pooling_weights, Sequence 76 ): 77 sims = [ 78 
unpack_float(s) * w 79 for s, w in zip(similarities, self.pooling_weights, strict=True) 80 ] 81 pooling_factor = len(similarities) / sum(self.pooling_weights) 82 elif isinstance(similarities, Sequence) and self.pooling_weights is None: 83 sims = [unpack_float(s) for s in similarities] 84 elif isinstance(similarities, Mapping) and self.pooling_weights is None: 85 sims = [unpack_float(s) for s in similarities.values()] 86 else: 87 raise NotImplementedError() 88 89 return pooling_func(sims) * pooling_factor
Aggregates local similarities to a global similarity using the specified pooling function.
Arguments:
- pooling: The pooling function to use. It can be either a string representing the name of the pooling function or a custom pooling function (see `cbrkit.typing.PoolingFunc`).
- pooling_weights: The weights to apply to the similarities during pooling. It can be a sequence or a mapping. If None, every local similarity is weighted equally.
- default_pooling_weight: The default weight to use if a similarity key is not found in the pooling_weights mapping.
Examples:
>>> agg = aggregator("mean")
>>> agg([0.5, 0.75, 1.0])
0.75
>>> agg = aggregator("mean", {1: 1, 2: 1, 3: 0})
>>> agg({1: 1, 2: 1, 3: 1})
1.0
>>> agg = aggregator("mean", {1: 1, 2: 1, 3: 2})
>>> agg({1: 1, 2: 1, 3: 1})
1.0
22@dataclass(slots=True, frozen=True) 23class AttributeValueSim[S: Float](StructuredValue[float]): 24 attributes: Mapping[str, S]