cbrkit.reuse.build
1import itertools 2from collections.abc import Sequence 3from dataclasses import dataclass 4from inspect import signature as inspect_signature 5from multiprocessing.pool import Pool 6from typing import cast, override 7 8from ..helpers import ( 9 batchify_adaptation, 10 batchify_sim, 11 chunkify, 12 get_logger, 13 mp_count, 14 mp_map, 15 mp_starmap, 16 produce_factory, 17 use_mp, 18) 19from ..typing import ( 20 AnyAdaptationFunc, 21 AnySimFunc, 22 Casebase, 23 Float, 24 MapAdaptationFunc, 25 MaybeFactory, 26 ReduceAdaptationFunc, 27 ReuserFunc, 28 SimMap, 29 SimpleAdaptationFunc, 30) 31 32logger = get_logger(__name__) 33 34__all__ = ["build"] 35 36 37@dataclass(slots=True, frozen=True) 38class build[K, V, S: Float](ReuserFunc[K, V, S]): 39 """Builds a casebase by adapting cases using an adaptation function and a similarity function. 40 41 Args: 42 adaptation_func: The adaptation function that will be applied to the cases. 43 similarity_func: The similarity function that will be used to compare the adapted cases to the query. 44 multiprocessing: Multiprocessing configuration for adaptation. 45 chunksize: Number of batches to process at a time using the adaptation function. 46 If 0, it will be set to the number of batches divided by the number of processes. 47 48 Returns: 49 The adapted casebases and the similarities between the adapted cases and the query. 
50 """ 51 52 adaptation_func: MaybeFactory[AnyAdaptationFunc[K, V]] 53 similarity_func: MaybeFactory[AnySimFunc[V, S]] 54 multiprocessing: Pool | int | bool = False 55 chunksize: int = 0 56 57 @override 58 def __call__( 59 self, 60 batches: Sequence[tuple[Casebase[K, V], V]], 61 ) -> Sequence[tuple[Casebase[K, V], SimMap[K, S]]]: 62 adaptation_func = produce_factory(self.adaptation_func) 63 adapted_casebases = self._adapt(batches, adaptation_func) 64 adapted_batches = [ 65 (adapted_casebase, query) 66 for adapted_casebase, (_, query) in zip( 67 adapted_casebases, batches, strict=True 68 ) 69 ] 70 71 # Score adapted cases against queries 72 sim_func = batchify_sim(produce_factory(self.similarity_func)) 73 74 flat_batches: list[tuple[V, V]] = [] 75 flat_index: list[tuple[int, K]] = [] 76 77 for idx, (casebase, query) in enumerate(adapted_batches): 78 for key, case in casebase.items(): 79 flat_index.append((idx, key)) 80 flat_batches.append((case, query)) 81 82 scores = sim_func(flat_batches) 83 84 sim_maps: list[dict[K, S]] = [{} for _ in adapted_batches] 85 for (idx, key), score in zip(flat_index, scores, strict=True): 86 sim_maps[idx][key] = score 87 88 return [ 89 (adapted_casebase, sim_map) 90 for adapted_casebase, sim_map in zip( 91 adapted_casebases, sim_maps, strict=True 92 ) 93 ] 94 95 def _adapt( 96 self, 97 batches: Sequence[tuple[Casebase[K, V], V]], 98 adaptation_func: AnyAdaptationFunc[K, V], 99 ) -> Sequence[Casebase[K, V]]: 100 adaptation_func_signature = inspect_signature(adaptation_func) 101 102 if "casebase" in adaptation_func_signature.parameters: 103 adapt_func = cast( 104 MapAdaptationFunc[K, V] | ReduceAdaptationFunc[K, V], 105 adaptation_func, 106 ) 107 adaptation_results = mp_starmap( 108 adapt_func, 109 batches, 110 self.multiprocessing, 111 logger, 112 ) 113 114 if all(isinstance(item, tuple) for item in adaptation_results): 115 adaptation_results = cast(Sequence[tuple[K, V]], adaptation_results) 116 return [ 117 {adapted_key: adapted_case} 
118 for adapted_key, adapted_case in adaptation_results 119 ] 120 121 return cast(Sequence[Casebase[K, V]], adaptation_results) 122 123 adapt_func = batchify_adaptation(cast(SimpleAdaptationFunc[V], adaptation_func)) 124 batches_index: list[tuple[int, K]] = [] 125 flat_batches: list[tuple[V, V]] = [] 126 127 for idx, (casebase, query) in enumerate(batches): 128 for key, case in casebase.items(): 129 batches_index.append((idx, key)) 130 flat_batches.append((case, query)) 131 132 adapted_cases: Sequence[V] 133 134 if use_mp(self.multiprocessing) or self.chunksize > 0: 135 chunksize = ( 136 self.chunksize 137 if self.chunksize > 0 138 else len(flat_batches) // mp_count(self.multiprocessing) 139 ) 140 batch_chunks = list(chunkify(flat_batches, chunksize)) 141 adapted_chunks = mp_map( 142 adapt_func, batch_chunks, self.multiprocessing, logger 143 ) 144 adapted_cases = list(itertools.chain.from_iterable(adapted_chunks)) 145 else: 146 adapted_cases = list(adapt_func(flat_batches)) 147 148 adapted_casebases: list[dict[K, V]] = [{} for _ in range(len(batches))] 149 150 for (idx, key), adapted_case in zip(batches_index, adapted_cases, strict=True): 151 adapted_casebases[idx][key] = adapted_case 152 153 return adapted_casebases
@dataclass(slots=True, frozen=True)
class
build38@dataclass(slots=True, frozen=True) 39class build[K, V, S: Float](ReuserFunc[K, V, S]): 40 """Builds a casebase by adapting cases using an adaptation function and a similarity function. 41 42 Args: 43 adaptation_func: The adaptation function that will be applied to the cases. 44 similarity_func: The similarity function that will be used to compare the adapted cases to the query. 45 multiprocessing: Multiprocessing configuration for adaptation. 46 chunksize: Number of batches to process at a time using the adaptation function. 47 If 0, it will be set to the number of batches divided by the number of processes. 48 49 Returns: 50 The adapted casebases and the similarities between the adapted cases and the query. 51 """ 52 53 adaptation_func: MaybeFactory[AnyAdaptationFunc[K, V]] 54 similarity_func: MaybeFactory[AnySimFunc[V, S]] 55 multiprocessing: Pool | int | bool = False 56 chunksize: int = 0 57 58 @override 59 def __call__( 60 self, 61 batches: Sequence[tuple[Casebase[K, V], V]], 62 ) -> Sequence[tuple[Casebase[K, V], SimMap[K, S]]]: 63 adaptation_func = produce_factory(self.adaptation_func) 64 adapted_casebases = self._adapt(batches, adaptation_func) 65 adapted_batches = [ 66 (adapted_casebase, query) 67 for adapted_casebase, (_, query) in zip( 68 adapted_casebases, batches, strict=True 69 ) 70 ] 71 72 # Score adapted cases against queries 73 sim_func = batchify_sim(produce_factory(self.similarity_func)) 74 75 flat_batches: list[tuple[V, V]] = [] 76 flat_index: list[tuple[int, K]] = [] 77 78 for idx, (casebase, query) in enumerate(adapted_batches): 79 for key, case in casebase.items(): 80 flat_index.append((idx, key)) 81 flat_batches.append((case, query)) 82 83 scores = sim_func(flat_batches) 84 85 sim_maps: list[dict[K, S]] = [{} for _ in adapted_batches] 86 for (idx, key), score in zip(flat_index, scores, strict=True): 87 sim_maps[idx][key] = score 88 89 return [ 90 (adapted_casebase, sim_map) 91 for adapted_casebase, sim_map in zip( 92 
adapted_casebases, sim_maps, strict=True 93 ) 94 ] 95 96 def _adapt( 97 self, 98 batches: Sequence[tuple[Casebase[K, V], V]], 99 adaptation_func: AnyAdaptationFunc[K, V], 100 ) -> Sequence[Casebase[K, V]]: 101 adaptation_func_signature = inspect_signature(adaptation_func) 102 103 if "casebase" in adaptation_func_signature.parameters: 104 adapt_func = cast( 105 MapAdaptationFunc[K, V] | ReduceAdaptationFunc[K, V], 106 adaptation_func, 107 ) 108 adaptation_results = mp_starmap( 109 adapt_func, 110 batches, 111 self.multiprocessing, 112 logger, 113 ) 114 115 if all(isinstance(item, tuple) for item in adaptation_results): 116 adaptation_results = cast(Sequence[tuple[K, V]], adaptation_results) 117 return [ 118 {adapted_key: adapted_case} 119 for adapted_key, adapted_case in adaptation_results 120 ] 121 122 return cast(Sequence[Casebase[K, V]], adaptation_results) 123 124 adapt_func = batchify_adaptation(cast(SimpleAdaptationFunc[V], adaptation_func)) 125 batches_index: list[tuple[int, K]] = [] 126 flat_batches: list[tuple[V, V]] = [] 127 128 for idx, (casebase, query) in enumerate(batches): 129 for key, case in casebase.items(): 130 batches_index.append((idx, key)) 131 flat_batches.append((case, query)) 132 133 adapted_cases: Sequence[V] 134 135 if use_mp(self.multiprocessing) or self.chunksize > 0: 136 chunksize = ( 137 self.chunksize 138 if self.chunksize > 0 139 else len(flat_batches) // mp_count(self.multiprocessing) 140 ) 141 batch_chunks = list(chunkify(flat_batches, chunksize)) 142 adapted_chunks = mp_map( 143 adapt_func, batch_chunks, self.multiprocessing, logger 144 ) 145 adapted_cases = list(itertools.chain.from_iterable(adapted_chunks)) 146 else: 147 adapted_cases = list(adapt_func(flat_batches)) 148 149 adapted_casebases: list[dict[K, V]] = [{} for _ in range(len(batches))] 150 151 for (idx, key), adapted_case in zip(batches_index, adapted_cases, strict=True): 152 adapted_casebases[idx][key] = adapted_case 153 154 return adapted_casebases
Builds a casebase by adapting cases using an adaptation function and a similarity function.
Args:
- adaptation_func: The adaptation function that will be applied to the cases.
- similarity_func: The similarity function that will be used to compare the adapted cases to the query.
- multiprocessing: Multiprocessing configuration for adaptation (a `Pool` instance, a process count, or a boolean).
- chunksize: Number of case/query pairs to process at a time using the adaptation function. If 0, it will be set to the number of pairs divided by the number of worker processes.
Returns:
The adapted casebases and the similarities between the adapted cases and the query.