cbrkit.reuse.build
1from collections.abc import Sequence 2from dataclasses import dataclass 3from inspect import signature as inspect_signature 4from multiprocessing.pool import Pool 5from typing import cast, override 6 7from ..helpers import get_logger, mp_starmap, produce_factory 8from ..typing import ( 9 AdaptationFunc, 10 AnyAdaptationFunc, 11 Casebase, 12 Float, 13 MapAdaptationFunc, 14 MaybeFactory, 15 ReduceAdaptationFunc, 16 RetrieverFunc, 17 ReuserFunc, 18 SimMap, 19) 20 21logger = get_logger(__name__) 22 23 24@dataclass(slots=True, frozen=True) 25class build[K, V, S: Float](ReuserFunc[K, V, S]): 26 """Builds a casebase by adapting cases using an adaptation function and a similarity function. 27 28 Args: 29 adaptation_func: The adaptation function that will be applied to the cases. 30 retriever_func: The similarity function that will be used to compare the adapted cases to the query. 31 processes: The number of processes that will be used to apply the adaptation function. If processes is set to 1, the adaptation function will be applied in the main process. 32 33 Returns: 34 The adapted casebases and the similarities between the adapted cases and the query. 
35 """ 36 37 adaptation_func: MaybeFactory[AnyAdaptationFunc[K, V]] 38 retriever_func: RetrieverFunc[K, V, S] 39 multiprocessing: Pool | int | bool = False 40 41 @override 42 def __call__( 43 self, 44 batches: Sequence[tuple[Casebase[K, V], V]], 45 ) -> Sequence[tuple[Casebase[K, V], SimMap[K, S]]]: 46 adaptation_func = produce_factory(self.adaptation_func) 47 adapted_casebases = self._adapt(batches, adaptation_func) 48 adapted_batches = [ 49 (adapted_casebase, query) 50 for adapted_casebase, (_, query) in zip( 51 adapted_casebases, batches, strict=True 52 ) 53 ] 54 adapted_similarities = self.retriever_func(adapted_batches) 55 56 return [ 57 (adapted_casebase, adapted_sim) 58 for adapted_casebase, adapted_sim in zip( 59 adapted_casebases, adapted_similarities, strict=True 60 ) 61 ] 62 63 def _adapt( 64 self, 65 batches: Sequence[tuple[Casebase[K, V], V]], 66 adaptation_func: AnyAdaptationFunc[K, V], 67 ) -> Sequence[Casebase[K, V]]: 68 adaptation_func_signature = inspect_signature(adaptation_func) 69 70 if "casebase" in adaptation_func_signature.parameters: 71 adapt_func = cast( 72 MapAdaptationFunc[K, V] | ReduceAdaptationFunc[K, V], 73 adaptation_func, 74 ) 75 adaptation_results = mp_starmap( 76 adapt_func, 77 batches, 78 self.multiprocessing, 79 logger, 80 ) 81 82 if all(isinstance(item, tuple) for item in adaptation_results): 83 adaptation_results = cast(Sequence[tuple[K, V]], adaptation_results) 84 return [ 85 {adapted_key: adapted_case} 86 for adapted_key, adapted_case in adaptation_results 87 ] 88 89 return cast(Sequence[Casebase[K, V]], adaptation_results) 90 91 adapt_func = cast(AdaptationFunc[V], adaptation_func) 92 batches_index: list[tuple[int, K]] = [] 93 flat_batches: list[tuple[V, V]] = [] 94 95 for idx, (casebase, query) in enumerate(batches): 96 for key, case in casebase.items(): 97 batches_index.append((idx, key)) 98 flat_batches.append((case, query)) 99 100 adapted_cases = mp_starmap( 101 adapt_func, 102 flat_batches, 103 self.multiprocessing, 
104 logger, 105 ) 106 adapted_casebases: list[dict[K, V]] = [{} for _ in range(len(batches))] 107 108 for (idx, key), adapted_case in zip(batches_index, adapted_cases, strict=True): 109 adapted_casebases[idx][key] = adapted_case 110 111 return adapted_casebases
logger =
<Logger cbrkit.reuse.build (WARNING)>
@dataclass(slots=True, frozen=True)
class
build25@dataclass(slots=True, frozen=True) 26class build[K, V, S: Float](ReuserFunc[K, V, S]): 27 """Builds a casebase by adapting cases using an adaptation function and a similarity function. 28 29 Args: 30 adaptation_func: The adaptation function that will be applied to the cases. 31 retriever_func: The similarity function that will be used to compare the adapted cases to the query. 32 processes: The number of processes that will be used to apply the adaptation function. If processes is set to 1, the adaptation function will be applied in the main process. 33 34 Returns: 35 The adapted casebases and the similarities between the adapted cases and the query. 36 """ 37 38 adaptation_func: MaybeFactory[AnyAdaptationFunc[K, V]] 39 retriever_func: RetrieverFunc[K, V, S] 40 multiprocessing: Pool | int | bool = False 41 42 @override 43 def __call__( 44 self, 45 batches: Sequence[tuple[Casebase[K, V], V]], 46 ) -> Sequence[tuple[Casebase[K, V], SimMap[K, S]]]: 47 adaptation_func = produce_factory(self.adaptation_func) 48 adapted_casebases = self._adapt(batches, adaptation_func) 49 adapted_batches = [ 50 (adapted_casebase, query) 51 for adapted_casebase, (_, query) in zip( 52 adapted_casebases, batches, strict=True 53 ) 54 ] 55 adapted_similarities = self.retriever_func(adapted_batches) 56 57 return [ 58 (adapted_casebase, adapted_sim) 59 for adapted_casebase, adapted_sim in zip( 60 adapted_casebases, adapted_similarities, strict=True 61 ) 62 ] 63 64 def _adapt( 65 self, 66 batches: Sequence[tuple[Casebase[K, V], V]], 67 adaptation_func: AnyAdaptationFunc[K, V], 68 ) -> Sequence[Casebase[K, V]]: 69 adaptation_func_signature = inspect_signature(adaptation_func) 70 71 if "casebase" in adaptation_func_signature.parameters: 72 adapt_func = cast( 73 MapAdaptationFunc[K, V] | ReduceAdaptationFunc[K, V], 74 adaptation_func, 75 ) 76 adaptation_results = mp_starmap( 77 adapt_func, 78 batches, 79 self.multiprocessing, 80 logger, 81 ) 82 83 if all(isinstance(item, tuple) for item in 
adaptation_results): 84 adaptation_results = cast(Sequence[tuple[K, V]], adaptation_results) 85 return [ 86 {adapted_key: adapted_case} 87 for adapted_key, adapted_case in adaptation_results 88 ] 89 90 return cast(Sequence[Casebase[K, V]], adaptation_results) 91 92 adapt_func = cast(AdaptationFunc[V], adaptation_func) 93 batches_index: list[tuple[int, K]] = [] 94 flat_batches: list[tuple[V, V]] = [] 95 96 for idx, (casebase, query) in enumerate(batches): 97 for key, case in casebase.items(): 98 batches_index.append((idx, key)) 99 flat_batches.append((case, query)) 100 101 adapted_cases = mp_starmap( 102 adapt_func, 103 flat_batches, 104 self.multiprocessing, 105 logger, 106 ) 107 adapted_casebases: list[dict[K, V]] = [{} for _ in range(len(batches))] 108 109 for (idx, key), adapted_case in zip(batches_index, adapted_cases, strict=True): 110 adapted_casebases[idx][key] = adapted_case 111 112 return adapted_casebases
Builds a casebase by adapting cases using an adaptation function and a retriever function.
Arguments:
- adaptation_func: The adaptation function that will be applied to the cases.
- retriever_func: The retriever function that will be used to compare the adapted cases to the query.
- multiprocessing: Controls how the adaptation function is parallelized: pass a Pool, a number of processes, or a boolean (default False). If set to 1, the adaptation function will be applied in the main process.
Returns:
The adapted casebases and the similarities between the adapted cases and the query.
build( adaptation_func: MaybeFactory[AnyAdaptationFunc[K, V]], retriever_func: cbrkit.typing.RetrieverFunc[K, V, S], multiprocessing: multiprocessing.pool.Pool | int | bool = False)
retriever_func: cbrkit.typing.RetrieverFunc[K, V, S]