cbrkit.reuse.build

  1import itertools
  2from collections.abc import Sequence
  3from dataclasses import dataclass
  4from inspect import signature as inspect_signature
  5from multiprocessing.pool import Pool
  6from typing import cast, override
  7
  8from ..helpers import (
  9    batchify_adaptation,
 10    batchify_sim,
 11    chunkify,
 12    get_logger,
 13    mp_count,
 14    mp_map,
 15    mp_starmap,
 16    produce_factory,
 17    use_mp,
 18)
 19from ..typing import (
 20    AnyAdaptationFunc,
 21    AnySimFunc,
 22    Casebase,
 23    Float,
 24    MapAdaptationFunc,
 25    MaybeFactory,
 26    ReduceAdaptationFunc,
 27    ReuserFunc,
 28    SimMap,
 29    SimpleAdaptationFunc,
 30)
 31
 32logger = get_logger(__name__)
 33
 34__all__ = ["build"]
 35
 36
 37@dataclass(slots=True, frozen=True)
 38class build[K, V, S: Float](ReuserFunc[K, V, S]):
 39    """Builds a casebase by adapting cases using an adaptation function and a similarity function.
 40
 41    Args:
 42        adaptation_func: The adaptation function that will be applied to the cases.
 43        similarity_func: The similarity function that will be used to compare the adapted cases to the query.
 44        multiprocessing: Multiprocessing configuration for adaptation.
 45        chunksize: Number of batches to process at a time using the adaptation function.
 46            If 0, it will be set to the number of batches divided by the number of processes.
 47
 48    Returns:
 49        The adapted casebases and the similarities between the adapted cases and the query.
 50    """
 51
 52    adaptation_func: MaybeFactory[AnyAdaptationFunc[K, V]]
 53    similarity_func: MaybeFactory[AnySimFunc[V, S]]
 54    multiprocessing: Pool | int | bool = False
 55    chunksize: int = 0
 56
 57    @override
 58    def __call__(
 59        self,
 60        batches: Sequence[tuple[Casebase[K, V], V]],
 61    ) -> Sequence[tuple[Casebase[K, V], SimMap[K, S]]]:
 62        adaptation_func = produce_factory(self.adaptation_func)
 63        adapted_casebases = self._adapt(batches, adaptation_func)
 64        adapted_batches = [
 65            (adapted_casebase, query)
 66            for adapted_casebase, (_, query) in zip(
 67                adapted_casebases, batches, strict=True
 68            )
 69        ]
 70
 71        # Score adapted cases against queries
 72        sim_func = batchify_sim(produce_factory(self.similarity_func))
 73
 74        flat_batches: list[tuple[V, V]] = []
 75        flat_index: list[tuple[int, K]] = []
 76
 77        for idx, (casebase, query) in enumerate(adapted_batches):
 78            for key, case in casebase.items():
 79                flat_index.append((idx, key))
 80                flat_batches.append((case, query))
 81
 82        scores = sim_func(flat_batches)
 83
 84        sim_maps: list[dict[K, S]] = [{} for _ in adapted_batches]
 85        for (idx, key), score in zip(flat_index, scores, strict=True):
 86            sim_maps[idx][key] = score
 87
 88        return [
 89            (adapted_casebase, sim_map)
 90            for adapted_casebase, sim_map in zip(
 91                adapted_casebases, sim_maps, strict=True
 92            )
 93        ]
 94
 95    def _adapt(
 96        self,
 97        batches: Sequence[tuple[Casebase[K, V], V]],
 98        adaptation_func: AnyAdaptationFunc[K, V],
 99    ) -> Sequence[Casebase[K, V]]:
100        adaptation_func_signature = inspect_signature(adaptation_func)
101
102        if "casebase" in adaptation_func_signature.parameters:
103            adapt_func = cast(
104                MapAdaptationFunc[K, V] | ReduceAdaptationFunc[K, V],
105                adaptation_func,
106            )
107            adaptation_results = mp_starmap(
108                adapt_func,
109                batches,
110                self.multiprocessing,
111                logger,
112            )
113
114            if all(isinstance(item, tuple) for item in adaptation_results):
115                adaptation_results = cast(Sequence[tuple[K, V]], adaptation_results)
116                return [
117                    {adapted_key: adapted_case}
118                    for adapted_key, adapted_case in adaptation_results
119                ]
120
121            return cast(Sequence[Casebase[K, V]], adaptation_results)
122
123        adapt_func = batchify_adaptation(cast(SimpleAdaptationFunc[V], adaptation_func))
124        batches_index: list[tuple[int, K]] = []
125        flat_batches: list[tuple[V, V]] = []
126
127        for idx, (casebase, query) in enumerate(batches):
128            for key, case in casebase.items():
129                batches_index.append((idx, key))
130                flat_batches.append((case, query))
131
132        adapted_cases: Sequence[V]
133
134        if use_mp(self.multiprocessing) or self.chunksize > 0:
135            chunksize = (
136                self.chunksize
137                if self.chunksize > 0
138                else len(flat_batches) // mp_count(self.multiprocessing)
139            )
140            batch_chunks = list(chunkify(flat_batches, chunksize))
141            adapted_chunks = mp_map(
142                adapt_func, batch_chunks, self.multiprocessing, logger
143            )
144            adapted_cases = list(itertools.chain.from_iterable(adapted_chunks))
145        else:
146            adapted_cases = list(adapt_func(flat_batches))
147
148        adapted_casebases: list[dict[K, V]] = [{} for _ in range(len(batches))]
149
150        for (idx, key), adapted_case in zip(batches_index, adapted_cases, strict=True):
151            adapted_casebases[idx][key] = adapted_case
152
153        return adapted_casebases
@dataclass(slots=True, frozen=True)
class build(cbrkit.typing.ReuserFunc[K, V, S], typing.Generic[K, V, S]):
 38@dataclass(slots=True, frozen=True)
 39class build[K, V, S: Float](ReuserFunc[K, V, S]):
 40    """Builds a casebase by adapting cases using an adaptation function and a similarity function.
 41
 42    Args:
 43        adaptation_func: The adaptation function that will be applied to the cases.
 44        similarity_func: The similarity function that will be used to compare the adapted cases to the query.
 45        multiprocessing: Multiprocessing configuration for adaptation.
 46        chunksize: Number of batches to process at a time using the adaptation function.
 47            If 0, it will be set to the number of batches divided by the number of processes.
 48
 49    Returns:
 50        The adapted casebases and the similarities between the adapted cases and the query.
 51    """
 52
 53    adaptation_func: MaybeFactory[AnyAdaptationFunc[K, V]]
 54    similarity_func: MaybeFactory[AnySimFunc[V, S]]
 55    multiprocessing: Pool | int | bool = False
 56    chunksize: int = 0
 57
 58    @override
 59    def __call__(
 60        self,
 61        batches: Sequence[tuple[Casebase[K, V], V]],
 62    ) -> Sequence[tuple[Casebase[K, V], SimMap[K, S]]]:
 63        adaptation_func = produce_factory(self.adaptation_func)
 64        adapted_casebases = self._adapt(batches, adaptation_func)
 65        adapted_batches = [
 66            (adapted_casebase, query)
 67            for adapted_casebase, (_, query) in zip(
 68                adapted_casebases, batches, strict=True
 69            )
 70        ]
 71
 72        # Score adapted cases against queries
 73        sim_func = batchify_sim(produce_factory(self.similarity_func))
 74
 75        flat_batches: list[tuple[V, V]] = []
 76        flat_index: list[tuple[int, K]] = []
 77
 78        for idx, (casebase, query) in enumerate(adapted_batches):
 79            for key, case in casebase.items():
 80                flat_index.append((idx, key))
 81                flat_batches.append((case, query))
 82
 83        scores = sim_func(flat_batches)
 84
 85        sim_maps: list[dict[K, S]] = [{} for _ in adapted_batches]
 86        for (idx, key), score in zip(flat_index, scores, strict=True):
 87            sim_maps[idx][key] = score
 88
 89        return [
 90            (adapted_casebase, sim_map)
 91            for adapted_casebase, sim_map in zip(
 92                adapted_casebases, sim_maps, strict=True
 93            )
 94        ]
 95
 96    def _adapt(
 97        self,
 98        batches: Sequence[tuple[Casebase[K, V], V]],
 99        adaptation_func: AnyAdaptationFunc[K, V],
100    ) -> Sequence[Casebase[K, V]]:
101        adaptation_func_signature = inspect_signature(adaptation_func)
102
103        if "casebase" in adaptation_func_signature.parameters:
104            adapt_func = cast(
105                MapAdaptationFunc[K, V] | ReduceAdaptationFunc[K, V],
106                adaptation_func,
107            )
108            adaptation_results = mp_starmap(
109                adapt_func,
110                batches,
111                self.multiprocessing,
112                logger,
113            )
114
115            if all(isinstance(item, tuple) for item in adaptation_results):
116                adaptation_results = cast(Sequence[tuple[K, V]], adaptation_results)
117                return [
118                    {adapted_key: adapted_case}
119                    for adapted_key, adapted_case in adaptation_results
120                ]
121
122            return cast(Sequence[Casebase[K, V]], adaptation_results)
123
124        adapt_func = batchify_adaptation(cast(SimpleAdaptationFunc[V], adaptation_func))
125        batches_index: list[tuple[int, K]] = []
126        flat_batches: list[tuple[V, V]] = []
127
128        for idx, (casebase, query) in enumerate(batches):
129            for key, case in casebase.items():
130                batches_index.append((idx, key))
131                flat_batches.append((case, query))
132
133        adapted_cases: Sequence[V]
134
135        if use_mp(self.multiprocessing) or self.chunksize > 0:
136            chunksize = (
137                self.chunksize
138                if self.chunksize > 0
139                else len(flat_batches) // mp_count(self.multiprocessing)
140            )
141            batch_chunks = list(chunkify(flat_batches, chunksize))
142            adapted_chunks = mp_map(
143                adapt_func, batch_chunks, self.multiprocessing, logger
144            )
145            adapted_cases = list(itertools.chain.from_iterable(adapted_chunks))
146        else:
147            adapted_cases = list(adapt_func(flat_batches))
148
149        adapted_casebases: list[dict[K, V]] = [{} for _ in range(len(batches))]
150
151        for (idx, key), adapted_case in zip(batches_index, adapted_cases, strict=True):
152            adapted_casebases[idx][key] = adapted_case
153
154        return adapted_casebases

Builds a casebase by adapting cases using an adaptation function and a similarity function.

Arguments:
  • adaptation_func: The adaptation function that will be applied to the cases.
  • similarity_func: The similarity function that will be used to compare the adapted cases to the query.
  • multiprocessing: Multiprocessing configuration for adaptation.
  • chunksize: Number of batches to process at a time using the adaptation function. If 0, it will be set to the number of batches divided by the number of processes.
Returns:
  The adapted casebases and the similarities between the adapted cases and the query.

build(adaptation_func: MaybeFactory[AnyAdaptationFunc[K, V]], similarity_func: MaybeFactory[AnySimFunc[V, S]], multiprocessing: multiprocessing.pool.Pool | int | bool = False, chunksize: int = 0)
adaptation_func: MaybeFactory[AnyAdaptationFunc[K, V]]
similarity_func: MaybeFactory[AnySimFunc[V, S]]
multiprocessing: multiprocessing.pool.Pool | int | bool
chunksize: int