cbrkit.synthesis.providers
from ...helpers import optional_dependencies
from .model import (
    BaseProvider,
    ChatMessage,
    ChatPrompt,
    ChatProvider,
    DocumentsPrompt,
    Response,
    Usage,
)
from .wrappers import conversation, pipe

with optional_dependencies():
    from .openai import openai
with optional_dependencies():
    from .ollama import ollama
with optional_dependencies():
    from .cohere import cohere
with optional_dependencies():
    from .anthropic import anthropic
with optional_dependencies():
    from .instructor import instructor

__all__ = [
    "openai",
    "ollama",
    "cohere",
    "conversation",
    "pipe",
    "BaseProvider",
    "ChatProvider",
    "ChatMessage",
    "ChatPrompt",
    "DocumentsPrompt",
    "Response",
    "Usage",
    "anthropic",
    "instructor",
]
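Each provider import is wrapped in optional_dependencies(), so the module loads even when a given SDK (openai, ollama, cohere, anthropic, instructor) is not installed; the corresponding name is then simply unavailable. A minimal usage sketch, assuming the openai extra is installed; the model name and prompt are illustrative, and Response is assumed to expose its wrapped result as .value (as the use of unpack_value elsewhere suggests):

from cbrkit.synthesis.providers import openai

provider = openai(model="gpt-4o-mini", response_type=str)
responses = provider(["Summarize the differences between the retrieved cases."])
print(responses[0].value)  # unwrap the StructuredValue returned by the provider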
@dataclass(slots=True)
class openai[R: BaseModel | str](ChatProvider[OpenaiPrompt, R]):
    tool_choice: type[BaseModel] | str | None = None
    client: AsyncOpenAI = field(default_factory=AsyncOpenAI, repr=False)
    frequency_penalty: float | None = None
    logit_bias: dict[str, int] | None = None
    logprobs: bool | None = None
    max_completion_tokens: int | None = None
    metadata: dict[str, str] | None = None
    n: int | None = None
    presence_penalty: float | None = None
    seed: int | None = None
    stop: str | list[str] | None = None
    store: bool | None = None
    reasoning_effort: Literal["low", "medium", "high"] | None = None
    temperature: float | None = None
    top_logprobs: int | None = None
    top_p: float | None = None
    extra_headers: Any | None = None
    extra_query: Any | None = None
    extra_body: Any | None = None
    timeout: float | Timeout | None = None

    @override
    async def __call_batch__(self, prompt: OpenaiPrompt) -> Response[R]:
        messages: list[ChatCompletionMessageParam] = []

        if self.system_message is not None:
            messages.append(
                {
                    "role": "system",
                    "content": self.system_message,
                }
            )

        messages.extend(cast(Sequence[ChatCompletionMessageParam], self.messages))

        if isinstance(prompt, ChatPrompt):
            messages.extend(
                cast(Sequence[ChatCompletionMessageParam], prompt.messages)
            )

        if messages and messages[-1]["role"] == "user":
            messages.append(
                {
                    "role": "assistant",
                    "content": unpack_value(prompt),
                }
            )
        else:
            messages.append(
                {
                    "role": "user",
                    "content": unpack_value(prompt),
                }
            )

        # A union response_type is exposed as a set of function tools,
        # a single BaseModel with an explicit tool_choice as a single tool.
        tools: list[ChatCompletionToolParam] | None = None
        tool_choice: ChatCompletionNamedToolChoiceParam | None = None
        response_type_origin = get_origin(self.response_type)

        if response_type_origin is UnionType or response_type_origin is Union:
            tools = [
                pydantic_function_tool(tool)
                for tool in get_args(self.response_type)
                if issubclass(tool, BaseModel)
            ]
        elif (
            issubclass(self.response_type, BaseModel)
            and self.tool_choice is not None
        ):
            tools = [pydantic_function_tool(self.response_type)]

        if self.tool_choice is not None:
            tool_choice = {
                "type": "function",
                "function": {
                    "name": self.tool_choice
                    if isinstance(self.tool_choice, str)
                    else self.tool_choice.__name__,
                },
            }

        try:
            # Without tools, a BaseModel response_type is sent as response_format
            # (structured output); otherwise the tool-calling path is used.
            res = await self.client.beta.chat.completions.parse(
                model=self.model,
                messages=messages,
                response_format=self.response_type
                if tools is None and issubclass(self.response_type, BaseModel)
                else NOT_GIVEN,
                tools=if_given(tools),
                tool_choice=if_given(tool_choice),
                frequency_penalty=if_given(self.frequency_penalty),
                logit_bias=if_given(self.logit_bias),
                logprobs=if_given(self.logprobs),
                max_completion_tokens=if_given(self.max_completion_tokens),
                metadata=if_given(self.metadata),
                n=if_given(self.n),
                presence_penalty=if_given(self.presence_penalty),
                seed=if_given(self.seed),
                stop=if_given(self.stop),
                store=if_given(self.store),
                reasoning_effort=if_given(self.reasoning_effort),
                temperature=if_given(self.temperature),
                top_logprobs=if_given(self.top_logprobs),
                top_p=if_given(self.top_p),
                extra_headers=self.extra_headers,
                extra_query=self.extra_query,
                extra_body=self.extra_body,
                timeout=if_given(self.timeout),
                **self.extra_kwargs,
            )
        except ValidationError as e:
            for error in e.errors():
                logger.error(f"Invalid response ({error['msg']}): {error['input']}")
            raise

        choice = res.choices[0]
        message = choice.message

        assert res.usage is not None
        usage = Usage(res.usage.prompt_tokens, res.usage.completion_tokens)

        if choice.finish_reason == "length":
            raise ValueError("Length limit", res)

        if choice.finish_reason == "content_filter":
            raise ValueError("Content filter", res)

        if message.refusal:
            raise ValueError("Refusal", res)

        if (
            isinstance(self.response_type, type)
            and issubclass(self.response_type, BaseModel)
            and (parsed := message.parsed) is not None
        ):
            return Response(cast(R, parsed), usage)

        if (
            isinstance(self.response_type, type)
            and issubclass(self.response_type, str)
            and (content := message.content) is not None
        ):
            return Response(cast(R, content), usage)

        if (
            tools is not None
            and (tool_calls := message.tool_calls) is not None
            and (parsed := tool_calls[0].function.parsed_arguments) is not None
        ):
            return Response(cast(R, parsed), usage)

        raise ValueError("Invalid response", res)
@dataclass(slots=True)
class ollama[R: str | BaseModel](ChatProvider[OllamaPrompt, R]):
    client: AsyncClient = field(default_factory=AsyncClient, repr=False)
    options: Options | None = None
    keep_alive: float | str | None = None

    @override
    async def __call_batch__(self, prompt: OllamaPrompt) -> Response[R]:
        messages: list[Message] = []

        if self.system_message is not None:
            messages.append(Message(role="system", content=self.system_message))

        messages.extend(
            Message(role=msg.role, content=msg.content) for msg in self.messages
        )

        if isinstance(prompt, ChatPrompt):
            messages.extend(
                Message(role=msg.role, content=msg.content)
                for msg in prompt.messages
            )

        if self.messages and self.messages[-1].role == "user":
            messages.append(Message(role="assistant", content=unpack_value(prompt)))
        else:
            messages.append(Message(role="user", content=unpack_value(prompt)))

        res = await self.client.chat(
            model=self.model,
            messages=messages,
            options=self.options,
            keep_alive=self.keep_alive,
            format=self.response_type.model_json_schema()
            if issubclass(self.response_type, BaseModel)
            else None,
            **self.extra_kwargs,
        )

        content = res["message"]["content"]

        if self.response_type is str:
            return Response(content)

        return Response(json.loads(content))
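With a BaseModel response_type, the ollama provider sends the model's JSON schema via format and parses the returned content with json.loads; with str it returns the raw content. A hedged sketch; the model name and options are illustrative:

from ollama import Options

from cbrkit.synthesis.providers import ollama

provider = ollama(
    model="llama3.2",
    response_type=str,
    options=Options(temperature=0.2),
)
[response] = provider(["Describe the adapted case in one paragraph."])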
@dataclass(slots=True)
class cohere[R: str | BaseModel](ChatProvider[CoherePrompt, R]):
    client: AsyncClient = field(default_factory=AsyncClient, repr=False)
    request_options: RequestOptions | None = None
    citation_options: CitationOptions | None = None
    safety_mode: V2ChatRequestSafetyMode | None = None
    max_tokens: int | None = None
    stop_sequences: Sequence[str] | None = None
    temperature: float | None = None
    seed: int | None = None
    frequency_penalty: float | None = None
    presence_penalty: float | None = None
    k: float | None = None
    p: float | None = None
    logprobs: bool | None = None

    @override
    async def __call_batch__(self, prompt: CoherePrompt) -> Response[R]:
        if isinstance(prompt, DocumentsPrompt) and issubclass(
            self.response_type, BaseModel
        ):
            raise ValueError(
                "Structured output format is not supported when using documents"
            )

        messages: list[ChatMessageV2] = []

        if self.system_message is not None:
            messages.append(SystemChatMessageV2(content=self.system_message))

        if isinstance(prompt, ChatPrompt):
            messages.extend(
                UserChatMessageV2(content=msg.content)
                if msg.role == "user"
                else AssistantChatMessageV2(content=msg.content)
                for msg in prompt.messages
            )

        if self.messages and self.messages[-1].role == "user":
            messages.append(AssistantChatMessageV2(content=unpack_value(prompt)))
        else:
            messages.append(UserChatMessageV2(content=unpack_value(prompt)))

        res = await self.client.v2.chat(
            model=self.model,
            messages=messages,
            request_options=self.request_options,
            documents=[
                Document(id=id, data=cast(dict[str, str], data))
                for id, data in prompt.documents.items()
            ]
            if isinstance(prompt, DocumentsPrompt)
            else None,
            response_format=JsonObjectResponseFormatV2(
                json_schema=self.response_type.model_json_schema()
            )
            if issubclass(self.response_type, BaseModel)
            else None,
            citation_options=self.citation_options,
            safety_mode=self.safety_mode,
            max_tokens=self.max_tokens,
            stop_sequences=self.stop_sequences,
            temperature=self.temperature,
            seed=self.seed,
            frequency_penalty=self.frequency_penalty,
            presence_penalty=self.presence_penalty,
            k=self.k,
            p=self.p,
            logprobs=self.logprobs,
            **self.extra_kwargs,
        )

        content = res.message.content

        if content is None:
            raise ValueError("The completion is empty")

        if issubclass(self.response_type, BaseModel):
            if len(content) != 1:
                raise ValueError("The completion is empty or has multiple outputs")

            return Response(self.response_type.model_validate_json(content[0].text))

        return Response(cast(R, "\n".join(x.text for x in content)))
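The cohere provider additionally accepts DocumentsPrompt values, which are forwarded as Cohere documents; structured output via a BaseModel response_type is rejected in that case. A hedged sketch of the documents path, assuming DocumentsPrompt takes the prompt text as its first (value) argument like ChatPrompt does; the model name and document contents are illustrative:

from cbrkit.synthesis.providers import DocumentsPrompt, cohere

provider = cohere(model="command-r-plus", response_type=str)
prompt = DocumentsPrompt(
    "Answer the question using only the documents provided.",
    documents={
        "case-1": {"title": "Sedan", "text": "Four doors, 1.6 l engine."},
        "case-2": {"title": "Coupe", "text": "Two doors, 2.0 l engine."},
    },
)
[response] = provider([prompt])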
@dataclass(slots=True, frozen=True)
class conversation[P, R](ConversionFunc[P, R]):
    generation_func: AnyConversionFunc[ChatPrompt[P], Value[R]]
    conversion_func: ConversionFunc[R, P]
    chat_func: ConversionFunc[list[ChatMessage], P | None]

    def __call__(self, batch: P) -> R:
        func = unbatchify_conversion(self.generation_func)

        messages: list[ChatMessage] = [ChatMessage(role="user", content=batch)]
        last_assistant_message: R = unpack_value(func(ChatPrompt(batch, messages)))
        messages.append(
            ChatMessage(
                role="assistant",
                content=self.conversion_func(last_assistant_message),
            )
        )

        while next_batch := self.chat_func(messages):
            messages.append(ChatMessage(role="user", content=next_batch))
            last_assistant_message = unpack_value(
                func(ChatPrompt(next_batch, messages))
            )

            messages.append(
                ChatMessage(
                    role="assistant",
                    content=self.conversion_func(last_assistant_message),
                )
            )

        return last_assistant_message
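conversation drives a multi-turn exchange: it sends the initial prompt, converts each assistant answer back into a prompt with conversion_func, and keeps asking chat_func for the next user turn until it returns None. A hedged sketch with a hypothetical follow-up policy; the model name is illustrative:

from cbrkit.synthesis.providers import ChatMessage, conversation, openai

def follow_up(messages: list[ChatMessage]) -> str | None:
    # hypothetical policy: one refinement round, then stop
    if len(messages) >= 4:
        return None
    return "Please make the previous answer more concise."

chat = conversation(
    generation_func=openai(model="gpt-4o-mini", response_type=str),
    conversion_func=lambda text: text,  # answers are already valid prompts
    chat_func=follow_up,
)
final_answer = chat("Suggest an adaptation for the retrieved case.")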
@dataclass(slots=True, frozen=True)
class pipe[P, R](BatchConversionFunc[P, R]):
    generation_funcs: MaybeSequence[AnyConversionFunc[P, Value[R]]]
    conversion_func: ConversionFunc[R, P]

    def __call__(self, batches: Sequence[P]) -> Sequence[R]:
        funcs = produce_sequence(self.generation_funcs)
        current_input = batches
        current_output: Sequence[R] = []

        for func in funcs:
            batch_func = batchify_conversion(func)
            current_output = unpack_values(batch_func(current_input))
            current_input = [self.conversion_func(output) for output in current_output]

        if not len(current_output) == len(batches):
            raise ValueError(
                "The number of outputs does not match the number of inputs, "
                "did you provide a generation function?"
            )

        return current_output
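pipe chains several generation functions: each stage's outputs are converted back into prompts via conversion_func and fed to the next stage, and the final outputs must match the inputs one-to-one. A hedged draft-then-refine sketch; the model names are illustrative:

from cbrkit.synthesis.providers import openai, pipe

pipeline = pipe(
    generation_funcs=[
        openai(model="gpt-4o-mini", response_type=str),
        openai(model="gpt-4o", response_type=str),
    ],
    conversion_func=lambda text: f"Improve the following answer:\n{text}",
)
outputs = pipeline(["Summarize the case differences."])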
@dataclass(slots=True, kw_only=True)
class BaseProvider[P, R](BatchConversionFunc[P, Response[R]], ABC):
    model: str
    response_type: type[R]
    delay: float = 0
    retries: int = 0
    default_response: R | None = None
    extra_kwargs: Mapping[str, Any] = field(default_factory=dict)

    def __call__(self, batches: Sequence[P]) -> Sequence[Response[R]]:
        return event_loop.get().run_until_complete(self.__call_batches__(batches))

    async def __call_batches__(self, batches: Sequence[P]) -> Sequence[Response[R]]:
        logger.info(f"Processing {len(batches)} batches with {self.model}")

        return await asyncio.gather(
            *(
                self.__call_batch_wrapper__(batch, idx)
                for idx, batch in enumerate(batches)
            )
        )

    async def __call_batch_wrapper__(
        self, prompt: P, idx: int, retry: int = 0
    ) -> Response[R]:
        if self.delay > 0 and retry == 0:
            await asyncio.sleep(idx * self.delay)

        try:
            result = await self.__call_batch__(prompt)
            logger.debug(f"Result of batch {idx + 1}: {result}")
            return result

        except Exception as e:
            if retry < self.retries:
                logger.info(f"Retrying batch {idx + 1}...")
                return await self.__call_batch_wrapper__(prompt, idx, retry + 1)

            if self.default_response is not None:
                logger.error(f"Error processing batch {idx + 1}: {e}")
                return Response(self.default_response, Usage(0, 0))

            raise e

    @abstractmethod
    async def __call_batch__(self, prompt: P) -> Response[R]: ...
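BaseProvider turns an async, per-prompt __call_batch__ into a synchronous batch call: all prompts are dispatched concurrently with asyncio.gather, delay staggers the first attempt of batch idx by idx * delay seconds, failed batches are retried up to retries times, and default_response (if set) is returned instead of re-raising. A minimal subclass sketch with a hypothetical echo backend, to illustrate that a concrete provider only has to implement __call_batch__:

from dataclasses import dataclass

from cbrkit.synthesis.providers import BaseProvider, Response, Usage

@dataclass(slots=True)
class echo(BaseProvider[str, str]):
    async def __call_batch__(self, prompt: str) -> Response[str]:
        # no network call; just demonstrate the required contract
        return Response(f"{self.model} would answer: {prompt}", Usage(0, 0))

provider = echo(model="dummy", response_type=str, retries=2, delay=0.5)
print(provider(["hello"])[0].value)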
@dataclass(slots=True, kw_only=True)
class ChatProvider[P, R](BaseProvider[P, R], ABC):
    system_message: str | None = None
    messages: Sequence[ChatMessage] = field(default_factory=tuple)
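ChatProvider adds a fixed system_message and a sequence of seed messages that every concrete provider prepends to each prompt. A hedged sketch pinning a system prompt and a few-shot exchange onto a provider; the model name and message contents are illustrative:

from cbrkit.synthesis.providers import ChatMessage, openai

provider = openai(
    model="gpt-4o-mini",
    response_type=str,
    system_message="You are a case-based reasoning assistant.",
    messages=[
        ChatMessage(role="user", content="How do case A and case B differ?"),
        ChatMessage(role="assistant", content="They differ in price and mileage."),
    ],
)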
@dataclass(slots=True, frozen=True)
class ChatPrompt[P](StructuredValue[P]):
    messages: Sequence[ChatMessage[P]]
@dataclass(slots=True, frozen=True)
class DocumentsPrompt[P](StructuredValue[P]):
    documents: Mapping[str, Mapping[str, str]]
@dataclass(slots=True, frozen=True)
class Response[T](StructuredValue[T]):
    usage: Usage = Field(default_factory=Usage)
@dataclass(slots=True)
class anthropic[R: str | BaseModel](ChatProvider[AnthropicPrompt, R]):
    max_tokens: int
    model: ModelParam
    client: AsyncAnthropic = field(default_factory=AsyncAnthropic, repr=False)
    metadata: MetadataParam | NotGiven = NOT_GIVEN
    stop_sequences: list[str] | NotGiven = NOT_GIVEN
    stream: NotGiven | Literal[False] = NOT_GIVEN
    system: str | Iterable[TextBlockParam] | NotGiven = NOT_GIVEN
    temperature: float | NotGiven = NOT_GIVEN
    tool_choice: ToolChoiceParam | NotGiven = NOT_GIVEN
    tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN
    top_k: int | NotGiven = NOT_GIVEN
    top_p: float | NotGiven = NOT_GIVEN
    extra_headers: Any | None = None
    extra_query: Any | None = None
    extra_body: Any | None = None
    timeout: float | Timeout | NotGiven | None = NOT_GIVEN

    @override
    async def __call_batch__(self, prompt: AnthropicPrompt) -> Response[R]:
        messages: list[MessageParam] = []

        if self.system_message is not None:
            messages.append(
                {
                    "role": "user",  # anthropic doesn't have a "system" role
                    "content": self.system_message,
                }
            )

        messages.extend(cast(Sequence[MessageParam], self.messages))

        if isinstance(prompt, ChatPrompt):
            messages.extend(cast(Sequence[MessageParam], prompt.messages))

        if self.messages and self.messages[-1].role == "user":
            messages.append({"role": "assistant", "content": unpack_value(prompt)})
        else:
            messages.append({"role": "user", "content": unpack_value(prompt)})

        tool = (
            pydantic_to_anthropic_schema(self.response_type)
            if issubclass(self.response_type, BaseModel)
            else None
        )

        toolchoice = (
            cast(ToolChoiceParam, {"type": "tool", "name": tool["name"]})
            if tool is not None
            else None
        )
        tool = cast(ToolParam, tool) if tool is not None else None

        res = await self.client.messages.create(
            model=self.model,
            messages=messages,
            max_tokens=self.max_tokens,
            tools=[tool] if tool is not None else NOT_GIVEN,
            tool_choice=toolchoice if toolchoice is not None else NOT_GIVEN,
        )

        if issubclass(self.response_type, BaseModel):
            # res.content should contain one ToolUseBlock
            if len(res.content) != 1:
                raise ValueError("Expected one ToolUseBlock, got", len(res.content))

            block = res.content[0]

            if block.type != "tool_use":
                raise ValueError("Expected one ToolUseBlock, got", block.type)

            return Response(self.response_type.model_validate(block.input))

        elif self.response_type is str:
            str_res = ""

            for block in res.content:
                if block.type == "text":
                    str_res += block.text

            return Response(cast(R, str_res))

        raise ValueError("Invalid response", res)
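For a BaseModel response_type, the anthropic provider converts the schema into a single tool, forces the model to call it, and validates the tool input; for str it concatenates the returned text blocks. max_tokens and model are required fields. A hedged sketch; the model name and schema are illustrative:

from pydantic import BaseModel

from cbrkit.synthesis.providers import anthropic

class ReviewNote(BaseModel):
    summary: str
    issues: list[str]

provider = anthropic(
    max_tokens=1024,
    model="claude-3-5-sonnet-latest",
    response_type=ReviewNote,
)
[response] = provider(["Review the adapted case for inconsistencies."])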
@dataclass(slots=True)
class instructor[R: BaseModel](ChatProvider[InstructorPrompt, R]):
    client: AsyncInstructor = field(repr=False)
    strict: bool = True
    context: dict[str, Any] | None = None

    @override
    async def __call_batch__(self, prompt: InstructorPrompt) -> Response[R]:
        messages: list[ChatCompletionMessageParam] = []

        if self.system_message is not None:
            messages.append(
                {
                    "role": "system",
                    "content": self.system_message,
                }
            )

        messages.extend(cast(Sequence[ChatCompletionMessageParam], self.messages))

        if isinstance(prompt, ChatPrompt):
            messages.extend(
                cast(Sequence[ChatCompletionMessageParam], prompt.messages)
            )

        if messages and messages[-1]["role"] == "user":
            messages.append(
                {
                    "role": "assistant",
                    "content": unpack_value(prompt),
                }
            )
        else:
            messages.append(
                {
                    "role": "user",
                    "content": unpack_value(prompt),
                }
            )

        # retries are already handled by the base provider
        return Response(
            await self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                response_model=self.response_type,
                context=self.context,
                **self.extra_kwargs,
            )
        )
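The instructor provider delegates structured output entirely to the instructor package, so response_type must be a BaseModel and a pre-configured AsyncInstructor client is required. A hedged sketch wrapping an AsyncOpenAI client with instructor.from_openai; the model name and schema are illustrative:

import instructor as instructor_lib
from openai import AsyncOpenAI
from pydantic import BaseModel

from cbrkit.synthesis.providers import instructor

class Adaptation(BaseModel):
    changes: list[str]

provider = instructor(
    model="gpt-4o-mini",
    response_type=Adaptation,
    client=instructor_lib.from_openai(AsyncOpenAI()),
)
[response] = provider(["Propose adaptations for the retrieved case."])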