cbrkit.reuse.build

  1from collections.abc import Sequence
  2from dataclasses import dataclass
  3from inspect import signature as inspect_signature
  4from multiprocessing.pool import Pool
  5from typing import cast, override
  6
  7from ..helpers import get_logger, mp_starmap, produce_factory
  8from ..typing import (
  9    AdaptationFunc,
 10    AnyAdaptationFunc,
 11    Casebase,
 12    Float,
 13    MapAdaptationFunc,
 14    MaybeFactory,
 15    ReduceAdaptationFunc,
 16    RetrieverFunc,
 17    ReuserFunc,
 18    SimMap,
 19)
 20
 21logger = get_logger(__name__)
 22
 23
 24@dataclass(slots=True, frozen=True)
 25class build[K, V, S: Float](ReuserFunc[K, V, S]):
 26    """Builds a casebase by adapting cases using an adaptation function and a similarity function.
 27
 28    Args:
 29        adaptation_func: The adaptation function that will be applied to the cases.
 30        retriever_func: The similarity function that will be used to compare the adapted cases to the query.
 31        processes: The number of processes that will be used to apply the adaptation function. If processes is set to 1, the adaptation function will be applied in the main process.
 32
 33    Returns:
 34        The adapted casebases and the similarities between the adapted cases and the query.
 35    """
 36
 37    adaptation_func: MaybeFactory[AnyAdaptationFunc[K, V]]
 38    retriever_func: RetrieverFunc[K, V, S]
 39    multiprocessing: Pool | int | bool = False
 40
 41    @override
 42    def __call__(
 43        self,
 44        batches: Sequence[tuple[Casebase[K, V], V]],
 45    ) -> Sequence[tuple[Casebase[K, V], SimMap[K, S]]]:
 46        adaptation_func = produce_factory(self.adaptation_func)
 47        adapted_casebases = self._adapt(batches, adaptation_func)
 48        adapted_batches = [
 49            (adapted_casebase, query)
 50            for adapted_casebase, (_, query) in zip(
 51                adapted_casebases, batches, strict=True
 52            )
 53        ]
 54        adapted_similarities = self.retriever_func(adapted_batches)
 55
 56        return [
 57            (adapted_casebase, adapted_sim)
 58            for adapted_casebase, adapted_sim in zip(
 59                adapted_casebases, adapted_similarities, strict=True
 60            )
 61        ]
 62
 63    def _adapt(
 64        self,
 65        batches: Sequence[tuple[Casebase[K, V], V]],
 66        adaptation_func: AnyAdaptationFunc[K, V],
 67    ) -> Sequence[Casebase[K, V]]:
 68        adaptation_func_signature = inspect_signature(adaptation_func)
 69
 70        if "casebase" in adaptation_func_signature.parameters:
 71            adapt_func = cast(
 72                MapAdaptationFunc[K, V] | ReduceAdaptationFunc[K, V],
 73                adaptation_func,
 74            )
 75            adaptation_results = mp_starmap(
 76                adapt_func,
 77                batches,
 78                self.multiprocessing,
 79                logger,
 80            )
 81
 82            if all(isinstance(item, tuple) for item in adaptation_results):
 83                adaptation_results = cast(Sequence[tuple[K, V]], adaptation_results)
 84                return [
 85                    {adapted_key: adapted_case}
 86                    for adapted_key, adapted_case in adaptation_results
 87                ]
 88
 89            return cast(Sequence[Casebase[K, V]], adaptation_results)
 90
 91        adapt_func = cast(AdaptationFunc[V], adaptation_func)
 92        batches_index: list[tuple[int, K]] = []
 93        flat_batches: list[tuple[V, V]] = []
 94
 95        for idx, (casebase, query) in enumerate(batches):
 96            for key, case in casebase.items():
 97                batches_index.append((idx, key))
 98                flat_batches.append((case, query))
 99
100        adapted_cases = mp_starmap(
101            adapt_func,
102            flat_batches,
103            self.multiprocessing,
104            logger,
105        )
106        adapted_casebases: list[dict[K, V]] = [{} for _ in range(len(batches))]
107
108        for (idx, key), adapted_case in zip(batches_index, adapted_cases, strict=True):
109            adapted_casebases[idx][key] = adapted_case
110
111        return adapted_casebases
logger = <Logger cbrkit.reuse.build (WARNING)>
@dataclass(slots=True, frozen=True)
class build(cbrkit.typing.ReuserFunc[K, V, S], typing.Generic[K, V, S]):
 25@dataclass(slots=True, frozen=True)
 26class build[K, V, S: Float](ReuserFunc[K, V, S]):
 27    """Builds a casebase by adapting cases using an adaptation function and a similarity function.
 28
 29    Args:
 30        adaptation_func: The adaptation function that will be applied to the cases.
 31        retriever_func: The similarity function that will be used to compare the adapted cases to the query.
 32        processes: The number of processes that will be used to apply the adaptation function. If processes is set to 1, the adaptation function will be applied in the main process.
 33
 34    Returns:
 35        The adapted casebases and the similarities between the adapted cases and the query.
 36    """
 37
 38    adaptation_func: MaybeFactory[AnyAdaptationFunc[K, V]]
 39    retriever_func: RetrieverFunc[K, V, S]
 40    multiprocessing: Pool | int | bool = False
 41
 42    @override
 43    def __call__(
 44        self,
 45        batches: Sequence[tuple[Casebase[K, V], V]],
 46    ) -> Sequence[tuple[Casebase[K, V], SimMap[K, S]]]:
 47        adaptation_func = produce_factory(self.adaptation_func)
 48        adapted_casebases = self._adapt(batches, adaptation_func)
 49        adapted_batches = [
 50            (adapted_casebase, query)
 51            for adapted_casebase, (_, query) in zip(
 52                adapted_casebases, batches, strict=True
 53            )
 54        ]
 55        adapted_similarities = self.retriever_func(adapted_batches)
 56
 57        return [
 58            (adapted_casebase, adapted_sim)
 59            for adapted_casebase, adapted_sim in zip(
 60                adapted_casebases, adapted_similarities, strict=True
 61            )
 62        ]
 63
 64    def _adapt(
 65        self,
 66        batches: Sequence[tuple[Casebase[K, V], V]],
 67        adaptation_func: AnyAdaptationFunc[K, V],
 68    ) -> Sequence[Casebase[K, V]]:
 69        adaptation_func_signature = inspect_signature(adaptation_func)
 70
 71        if "casebase" in adaptation_func_signature.parameters:
 72            adapt_func = cast(
 73                MapAdaptationFunc[K, V] | ReduceAdaptationFunc[K, V],
 74                adaptation_func,
 75            )
 76            adaptation_results = mp_starmap(
 77                adapt_func,
 78                batches,
 79                self.multiprocessing,
 80                logger,
 81            )
 82
 83            if all(isinstance(item, tuple) for item in adaptation_results):
 84                adaptation_results = cast(Sequence[tuple[K, V]], adaptation_results)
 85                return [
 86                    {adapted_key: adapted_case}
 87                    for adapted_key, adapted_case in adaptation_results
 88                ]
 89
 90            return cast(Sequence[Casebase[K, V]], adaptation_results)
 91
 92        adapt_func = cast(AdaptationFunc[V], adaptation_func)
 93        batches_index: list[tuple[int, K]] = []
 94        flat_batches: list[tuple[V, V]] = []
 95
 96        for idx, (casebase, query) in enumerate(batches):
 97            for key, case in casebase.items():
 98                batches_index.append((idx, key))
 99                flat_batches.append((case, query))
100
101        adapted_cases = mp_starmap(
102            adapt_func,
103            flat_batches,
104            self.multiprocessing,
105            logger,
106        )
107        adapted_casebases: list[dict[K, V]] = [{} for _ in range(len(batches))]
108
109        for (idx, key), adapted_case in zip(batches_index, adapted_cases, strict=True):
110            adapted_casebases[idx][key] = adapted_case
111
112        return adapted_casebases

Builds a casebase by adapting cases using an adaptation function and a retriever function.

Arguments:
  • adaptation_func: The adaptation function that will be applied to the cases.
  • retriever_func: The retriever function that will be used to compare the adapted cases to the query.
  • multiprocessing: A Pool instance, a number of processes, or a bool (default False) controlling how the adaptation function is applied. If set to 1, the adaptation function will be applied in the main process.
Returns:

The adapted casebases and the similarities between the adapted cases and the query.

build( adaptation_func: MaybeFactory[AnyAdaptationFunc[K, V]], retriever_func: cbrkit.typing.RetrieverFunc[K, V, S], multiprocessing: multiprocessing.pool.Pool | int | bool = False)
adaptation_func: MaybeFactory[AnyAdaptationFunc[K, V]]
retriever_func: cbrkit.typing.RetrieverFunc[K, V, S]
multiprocessing: multiprocessing.pool.Pool | int | bool