cbrkit.synthesis.providers

 1    from ...helpers import optional_dependencies
 2    from .model import (
 3        BaseProvider,
 4        ChatMessage,
 5        ChatPrompt,
 6        ChatProvider,
 7        DocumentsPrompt,
 8        Response,
 9        Usage,
10    )
11    from .wrappers import conversation, pipe
12
13    with optional_dependencies():
14        from .openai import openai
15    with optional_dependencies():
16        from .ollama import ollama
17    with optional_dependencies():
18        from .cohere import cohere
19    with optional_dependencies():
20        from .anthropic import anthropic
21    with optional_dependencies():
22        from .instructor import instructor
23
24    __all__ = [
25        "openai",
26        "ollama",
27        "cohere",
28        "conversation",
29        "pipe",
30        "BaseProvider",
31        "ChatProvider",
32        "ChatMessage",
33        "ChatPrompt",
34        "DocumentsPrompt",
35        "Response",
36        "Usage",
37        "anthropic",
38        "instructor",
39    ]
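Each import above is wrapped in optional_dependencies(), which presumably suppresses the ImportError when the corresponding extra is not installed, so a provider name is only present in this namespace when its dependency is available. A minimal sketch of checking for and constructing a provider (the model name is a placeholder):

    import cbrkit.synthesis.providers as providers

    if hasattr(providers, "openai"):
        # model and response_type are keyword-only, inherited from BaseProvider.
        provider = providers.openai(model="gpt-4o-mini", response_type=str)
    else:
        raise RuntimeError("install the openai extra to use this provider")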
@dataclass(slots=True)
class openai(cbrkit.synthesis.providers.ChatProvider[OpenaiPrompt, R], typing.Generic[R]):
 28    @dataclass(slots=True)
 29    class openai[R: BaseModel | str](ChatProvider[OpenaiPrompt, R]):
 30        tool_choice: type[BaseModel] | str | None = None
 31        client: AsyncOpenAI = field(default_factory=AsyncOpenAI, repr=False)
 32        frequency_penalty: float | None = None
 33        logit_bias: dict[str, int] | None = None
 34        logprobs: bool | None = None
 35        max_completion_tokens: int | None = None
 36        metadata: dict[str, str] | None = None
 37        n: int | None = None
 38        presence_penalty: float | None = None
 39        seed: int | None = None
 40        stop: str | list[str] | None = None
 41        store: bool | None = None
 42        reasoning_effort: Literal["low", "medium", "high"] | None = None
 43        temperature: float | None = None
 44        top_logprobs: int | None = None
 45        top_p: float | None = None
 46        extra_headers: Any | None = None
 47        extra_query: Any | None = None
 48        extra_body: Any | None = None
 49        timeout: float | Timeout | None = None
 50
 51        @override
 52        async def __call_batch__(self, prompt: OpenaiPrompt) -> Response[R]:
 53            messages: list[ChatCompletionMessageParam] = []
 54
 55            if self.system_message is not None:
 56                messages.append(
 57                    {
 58                        "role": "system",
 59                        "content": self.system_message,
 60                    }
 61                )
 62
 63            messages.extend(cast(Sequence[ChatCompletionMessageParam], self.messages))
 64
 65            if isinstance(prompt, ChatPrompt):
 66                messages.extend(
 67                    cast(Sequence[ChatCompletionMessageParam], prompt.messages)
 68                )
 69
 70            if messages and messages[-1]["role"] == "user":
 71                messages.append(
 72                    {
 73                        "role": "assistant",
 74                        "content": unpack_value(prompt),
 75                    }
 76                )
 77            else:
 78                messages.append(
 79                    {
 80                        "role": "user",
 81                        "content": unpack_value(prompt),
 82                    }
 83                )
 84
 85            tools: list[ChatCompletionToolParam] | None = None
 86            tool_choice: ChatCompletionNamedToolChoiceParam | None = None
 87            response_type_origin = get_origin(self.response_type)
 88
 89            if response_type_origin is UnionType or response_type_origin is Union:
 90                tools = [
 91                    pydantic_function_tool(tool)
 92                    for tool in get_args(self.response_type)
 93                    if issubclass(tool, BaseModel)
 94                ]
 95            elif (
 96                issubclass(self.response_type, BaseModel)
 97                and self.tool_choice is not None
 98            ):
 99                tools = [pydantic_function_tool(self.response_type)]
100
101            if self.tool_choice is not None:
102                tool_choice = {
103                    "type": "function",
104                    "function": {
105                        "name": self.tool_choice
106                        if isinstance(self.tool_choice, str)
107                        else self.tool_choice.__name__,
108                    },
109                }
110
111            try:
112                res = await self.client.beta.chat.completions.parse(
113                    model=self.model,
114                    messages=messages,
115                    response_format=self.response_type
116                    if tools is None and issubclass(self.response_type, BaseModel)
117                    else NOT_GIVEN,
118                    tools=if_given(tools),
119                    tool_choice=if_given(tool_choice),
120                    frequency_penalty=if_given(self.frequency_penalty),
121                    logit_bias=if_given(self.logit_bias),
122                    logprobs=if_given(self.logprobs),
123                    max_completion_tokens=if_given(self.max_completion_tokens),
124                    metadata=if_given(self.metadata),
125                    n=if_given(self.n),
126                    presence_penalty=if_given(self.presence_penalty),
127                    seed=if_given(self.seed),
128                    stop=if_given(self.stop),
129                    store=if_given(self.store),
130                    reasoning_effort=if_given(self.reasoning_effort),
131                    temperature=if_given(self.temperature),
132                    top_logprobs=if_given(self.top_logprobs),
133                    top_p=if_given(self.top_p),
134                    extra_headers=self.extra_headers,
135                    extra_query=self.extra_query,
136                    extra_body=self.extra_body,
137                    timeout=if_given(self.timeout),
138                    **self.extra_kwargs,
139                )
140            except ValidationError as e:
141                for error in e.errors():
142                    logger.error(f"Invalid response ({error['msg']}): {error['input']}")
143                raise
144
145            choice = res.choices[0]
146            message = choice.message
147
148            assert res.usage is not None
149            usage = Usage(res.usage.prompt_tokens, res.usage.completion_tokens)
150
151            if choice.finish_reason == "length":
152                raise ValueError("Length limit", res)
153
154            if choice.finish_reason == "content_filter":
155                raise ValueError("Content filter", res)
156
157            if message.refusal:
158                raise ValueError("Refusal", res)
159
160            if (
161                isinstance(self.response_type, type)
162                and issubclass(self.response_type, BaseModel)
163                and (parsed := message.parsed) is not None
164            ):
165                return Response(cast(R, parsed), usage)
166
167            if (
168                isinstance(self.response_type, type)
169                and issubclass(self.response_type, str)
170                and (content := message.content) is not None
171            ):
172                return Response(cast(R, content), usage)
173
174            if (
175                tools is not None
176                and (tool_calls := message.tool_calls) is not None
177                and (parsed := tool_calls[0].function.parsed_arguments) is not None
178            ):
179                return Response(cast(R, parsed), usage)
180
181            raise ValueError("Invalid response", res)

openai( tool_choice: type[pydantic.main.BaseModel] | str | None = None, client: openai.AsyncOpenAI = <factory>, frequency_penalty: float | None = None, logit_bias: dict[str, int] | None = None, logprobs: bool | None = None, max_completion_tokens: int | None = None, metadata: dict[str, str] | None = None, n: int | None = None, presence_penalty: float | None = None, seed: int | None = None, stop: str | list[str] | None = None, store: bool | None = None, reasoning_effort: Optional[Literal['low', 'medium', 'high']] = None, temperature: float | None = None, top_logprobs: int | None = None, top_p: float | None = None, extra_headers: typing.Any | None = None, extra_query: typing.Any | None = None, extra_body: typing.Any | None = None, timeout: float | openai.Timeout | None = None, *, model: str, response_type: type[R], delay: float = 0, retries: int = 0, default_response: Optional[R] = None, extra_kwargs: Mapping[str, typing.Any] = <factory>, system_message: str | None = None, messages: Sequence[ChatMessage] = <factory>)
tool_choice: type[pydantic.main.BaseModel] | str | None
client: openai.AsyncOpenAI
frequency_penalty: float | None
logit_bias: dict[str, int] | None
logprobs: bool | None
max_completion_tokens: int | None
metadata: dict[str, str] | None
n: int | None
presence_penalty: float | None
seed: int | None
stop: str | list[str] | None
store: bool | None
reasoning_effort: Optional[Literal['low', 'medium', 'high']]
temperature: float | None
top_logprobs: int | None
top_p: float | None
extra_headers: typing.Any | None
extra_query: typing.Any | None
extra_body: typing.Any | None
timeout: float | openai.Timeout | None
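A minimal usage sketch for the openai provider, assuming OPENAI_API_KEY is set in the environment (the default AsyncOpenAI client reads it), a placeholder model name, and a hypothetical Answer model for structured output:

    from pydantic import BaseModel

    from cbrkit.synthesis.providers import openai

    class Answer(BaseModel):
        text: str
        confidence: float

    provider = openai(
        model="gpt-4o-mini",              # placeholder model name
        response_type=Answer,
        temperature=0.0,
        system_message="Answer concisely.",
    )

    # __call__ takes a sequence of prompts and returns one Response per prompt.
    responses = provider(["Why is the sky blue?"])
    answer = responses[0].value           # parsed Answer instance
    tokens = responses[0].usage.total_tokens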
@dataclass(slots=True)
class ollama(cbrkit.synthesis.providers.ChatProvider[OllamaPrompt, R], typing.Generic[R]):
16    @dataclass(slots=True)
17    class ollama[R: str | BaseModel](ChatProvider[OllamaPrompt, R]):
18        client: AsyncClient = field(default_factory=AsyncClient, repr=False)
19        options: Options | None = None
20        keep_alive: float | str | None = None
21
22        @override
23        async def __call_batch__(self, prompt: OllamaPrompt) -> Response[R]:
24            messages: list[Message] = []
25
26            if self.system_message is not None:
27                messages.append(Message(role="system", content=self.system_message))
28
29            messages.extend(
30                Message(role=msg.role, content=msg.content) for msg in self.messages
31            )
32
33            if isinstance(prompt, ChatPrompt):
34                messages.extend(
35                    Message(role=msg.role, content=msg.content)
36                    for msg in prompt.messages
37                )
38
39            if self.messages and self.messages[-1].role == "user":
40                messages.append(Message(role="assistant", content=unpack_value(prompt)))
41            else:
42                messages.append(Message(role="user", content=unpack_value(prompt)))
43
44            res = await self.client.chat(
45                model=self.model,
46                messages=messages,
47                options=self.options,
48                keep_alive=self.keep_alive,
49                format=self.response_type.model_json_schema()
50                if issubclass(self.response_type, BaseModel)
51                else None,
52                **self.extra_kwargs,
53            )
54
55            content = res["message"]["content"]
56
57            if self.response_type is str:
58                return Response(content)
59
60            return Response(self.response_type.model_validate_json(content))
ollama( client: ollama._client.AsyncClient = <factory>, options: ollama._types.Options | None = None, keep_alive: float | str | None = None, *, model: str, response_type: type[R], delay: float = 0, retries: int = 0, default_response: Optional[R] = None, extra_kwargs: Mapping[str, typing.Any] = <factory>, system_message: str | None = None, messages: Sequence[ChatMessage] = <factory>)
client: ollama._client.AsyncClient
options: ollama._types.Options | None
keep_alive: float | str | None
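A sketch for the ollama provider, which talks to a local Ollama server through AsyncClient; the model name is a placeholder and must already be pulled on that server:

    from cbrkit.synthesis.providers import ollama

    provider = ollama(
        model="llama3.2",                 # placeholder model name
        response_type=str,
        system_message="You are a terse assistant.",
        keep_alive="5m",
    )

    summary = provider(["Summarize case-based reasoning in one sentence."])[0].value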
@dataclass(slots=True)
class cohere(cbrkit.synthesis.providers.ChatProvider[CoherePrompt, R], typing.Generic[R]):
 27    @dataclass(slots=True)
 28    class cohere[R: str | BaseModel](ChatProvider[CoherePrompt, R]):
 29        client: AsyncClient = field(default_factory=AsyncClient, repr=False)
 30        request_options: RequestOptions | None = None
 31        citation_options: CitationOptions | None = None
 32        safety_mode: V2ChatRequestSafetyMode | None = None
 33        max_tokens: int | None = None
 34        stop_sequences: Sequence[str] | None = None
 35        temperature: float | None = None
 36        seed: int | None = None
 37        frequency_penalty: float | None = None
 38        presence_penalty: float | None = None
 39        k: float | None = None
 40        p: float | None = None
 41        logprobs: bool | None = None
 42
 43        @override
 44        async def __call_batch__(self, prompt: CoherePrompt) -> Response[R]:
 45            if isinstance(prompt, DocumentsPrompt) and issubclass(
 46                self.response_type, BaseModel
 47            ):
 48                raise ValueError(
 49                    "Structured output format is not supported when using documents"
 50                )
 51
 52            messages: list[ChatMessageV2] = []
 53
 54            if self.system_message is not None:
 55                messages.append(SystemChatMessageV2(content=self.system_message))
 56
 57            if isinstance(prompt, ChatPrompt):
 58                messages.extend(
 59                    UserChatMessageV2(content=msg.content)
 60                    if msg.role == "user"
 61                    else AssistantChatMessageV2(content=msg.content)
 62                    for msg in prompt.messages
 63                )
 64
 65            if self.messages and self.messages[-1].role == "user":
 66                messages.append(AssistantChatMessageV2(content=unpack_value(prompt)))
 67            else:
 68                messages.append(UserChatMessageV2(content=unpack_value(prompt)))
 69
 70            res = await self.client.v2.chat(
 71                model=self.model,
 72                messages=messages,
 73                request_options=self.request_options,
 74                documents=[
 75                    Document(id=id, data=cast(dict[str, str], data))
 76                    for id, data in prompt.documents.items()
 77                ]
 78                if isinstance(prompt, DocumentsPrompt)
 79                else None,
 80                response_format=JsonObjectResponseFormatV2(
 81                    json_schema=self.response_type.model_json_schema()
 82                )
 83                if issubclass(self.response_type, BaseModel)
 84                else None,
 85                citation_options=self.citation_options,
 86                safety_mode=self.safety_mode,
 87                max_tokens=self.max_tokens,
 88                stop_sequences=self.stop_sequences,
 89                temperature=self.temperature,
 90                seed=self.seed,
 91                frequency_penalty=self.frequency_penalty,
 92                presence_penalty=self.presence_penalty,
 93                k=self.k,
 94                p=self.p,
 95                logprobs=self.logprobs,
 96                **self.extra_kwargs,
 97            )
 98
 99            content = res.message.content
100
101            if content is None:
102                raise ValueError("The completion is empty")
103
104            if issubclass(self.response_type, BaseModel):
105                if len(content) != 1:
106                    raise ValueError("The completion is empty or has multiple outputs")
107
108                return Response(self.response_type.model_validate_json(content[0].text))
109
110            return Response(cast(R, "\n".join(x.text for x in content)))
cohere( client: cohere.client.AsyncClient = <factory>, request_options: cohere.core.request_options.RequestOptions | None = None, citation_options: cohere.types.citation_options.CitationOptions | None = None, safety_mode: Union[Literal['CONTEXTUAL', 'STRICT', 'OFF'], Any, NoneType] = None, max_tokens: int | None = None, stop_sequences: Sequence[str] | None = None, temperature: float | None = None, seed: int | None = None, frequency_penalty: float | None = None, presence_penalty: float | None = None, k: float | None = None, p: float | None = None, logprobs: bool | None = None, *, model: str, response_type: type[R], delay: float = 0, retries: int = 0, default_response: Optional[R] = None, extra_kwargs: Mapping[str, typing.Any] = <factory>, system_message: str | None = None, messages: Sequence[ChatMessage] = <factory>)
client: cohere.client.AsyncClient
request_options: cohere.core.request_options.RequestOptions | None
citation_options: cohere.types.citation_options.CitationOptions | None
safety_mode: Union[Literal['CONTEXTUAL', 'STRICT', 'OFF'], Any, NoneType]
max_tokens: int | None
stop_sequences: Sequence[str] | None
temperature: float | None
seed: int | None
frequency_penalty: float | None
presence_penalty: float | None
k: float | None
p: float | None
logprobs: bool | None
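The cohere provider is the only one here that accepts a DocumentsPrompt, and the guard at the top of __call_batch__ rejects combining documents with a BaseModel response type. A sketch, assuming the Cohere client picks up its API key from the environment; the model name and documents are placeholders:

    from cbrkit.synthesis.providers import DocumentsPrompt, cohere

    provider = cohere(model="command-r", response_type=str)

    prompt = DocumentsPrompt(
        "Which case is most relevant to a customer asking about refunds?",
        documents={
            "case-1": {"title": "Refund after 30 days", "text": "..."},
            "case-2": {"title": "Warranty repair", "text": "..."},
        },
    )

    answer = provider([prompt])[0].value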
@dataclass(slots=True, frozen=True)
class conversation(cbrkit.typing.ConversionFunc[P, R], typing.Generic[P, R]):
22@dataclass(slots=True, frozen=True)
23class conversation[P, R](ConversionFunc[P, R]):
24    generation_func: AnyConversionFunc[ChatPrompt[P], Value[R]]
25    conversion_func: ConversionFunc[R, P]
26    chat_func: ConversionFunc[list[ChatMessage], P | None]
27
28    def __call__(self, batch: P) -> R:
29        func = unbatchify_conversion(self.generation_func)
30
31        messages: list[ChatMessage] = [ChatMessage(role="user", content=batch)]
32        last_assistant_message: R = unpack_value(func(ChatPrompt(batch, messages)))
33        messages.append(
34            ChatMessage(
35                role="assistant",
36                content=self.conversion_func(last_assistant_message),
37            )
38        )
39
40        while next_batch := self.chat_func(messages):
41            messages.append(ChatMessage(role="user", content=next_batch))
42            last_assistant_message = unpack_value(
43                func(ChatPrompt(next_batch, messages))
44            )
45
46            messages.append(
47                ChatMessage(
48                    role="assistant",
49                    content=self.conversion_func(last_assistant_message),
50                )
51            )
52
53        return last_assistant_message
conversation( generation_func: AnyConversionFunc[ChatPrompt[P], Value[R]], conversion_func: cbrkit.typing.ConversionFunc[R, P], chat_func: cbrkit.typing.ConversionFunc[list[ChatMessage], typing.Optional[P]])
generation_func: AnyConversionFunc[ChatPrompt[P], Value[R]]
conversion_func: cbrkit.typing.ConversionFunc[R, P]
chat_func: cbrkit.typing.ConversionFunc[list[ChatMessage], typing.Optional[P]]
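A sketch of driving a scripted two-turn exchange with conversation; the provider, the identity conversion_func, and the scripted follow-up are illustrative, and the model name is a placeholder:

    from cbrkit.synthesis.providers import conversation, openai

    provider = openai(model="gpt-4o-mini", response_type=str)

    follow_ups = iter(["And how does retrieval fit into that cycle?"])

    chat = conversation(
        generation_func=provider,
        conversion_func=lambda reply: reply,                # P and R are both str here
        chat_func=lambda messages: next(follow_ups, None),  # returning None ends the loop
    )

    final_reply = chat("What is the 4R cycle in case-based reasoning?")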
@dataclass(slots=True, frozen=True)
class pipe(cbrkit.typing.BatchConversionFunc[P, R], typing.Generic[P, R]):
56@dataclass(slots=True, frozen=True)
57class pipe[P, R](BatchConversionFunc[P, R]):
58    generation_funcs: MaybeSequence[AnyConversionFunc[P, Value[R]]]
59    conversion_func: ConversionFunc[R, P]
60
61    def __call__(self, batches: Sequence[P]) -> Sequence[R]:
62        funcs = produce_sequence(self.generation_funcs)
63        current_input = batches
64        current_output: Sequence[R] = []
65
66        for func in funcs:
67            batch_func = batchify_conversion(func)
68            current_output = unpack_values(batch_func(current_input))
69            current_input = [self.conversion_func(output) for output in current_output]
70
71        if not len(current_output) == len(batches):
72            raise ValueError(
73                "The number of outputs does not match the number of inputs, "
74                "did you provide a generation function?"
75            )
76
77        return current_output
pipe( generation_funcs: MaybeSequence[AnyConversionFunc[P, Value[R]]], conversion_func: cbrkit.typing.ConversionFunc[R, P])
generation_funcs: MaybeSequence[AnyConversionFunc[P, Value[R]]]
conversion_func: cbrkit.typing.ConversionFunc[R, P]
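A sketch of a two-stage pipeline with pipe, where conversion_func turns each stage's output back into the next stage's input; the model names are placeholders:

    from cbrkit.synthesis.providers import openai, pipe

    draft = openai(model="gpt-4o-mini", response_type=str, system_message="Draft an answer.")
    polish = openai(model="gpt-4o", response_type=str, system_message="Polish the draft.")

    pipeline = pipe(
        generation_funcs=[draft, polish],
        conversion_func=lambda text: text,   # outputs are plain strings, so identity suffices
    )

    results = pipeline(["Explain case adaptation in one paragraph."])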
@dataclass(slots=True, kw_only=True)
class BaseProvider(cbrkit.typing.BatchConversionFunc[P, cbrkit.synthesis.providers.model.Response[R]], abc.ABC, typing.Generic[P, R]):
47@dataclass(slots=True, kw_only=True)
48class BaseProvider[P, R](BatchConversionFunc[P, Response[R]], ABC):
49    model: str
50    response_type: type[R]
51    delay: float = 0
52    retries: int = 0
53    default_response: R | None = None
54    extra_kwargs: Mapping[str, Any] = field(default_factory=dict)
55
56    def __call__(self, batches: Sequence[P]) -> Sequence[Response[R]]:
57        return event_loop.get().run_until_complete(self.__call_batches__(batches))
58
59    async def __call_batches__(self, batches: Sequence[P]) -> Sequence[Response[R]]:
60        logger.info(f"Processing {len(batches)} batches with {self.model}")
61
62        return await asyncio.gather(
63            *(
64                self.__call_batch_wrapper__(batch, idx)
65                for idx, batch in enumerate(batches)
66            )
67        )
68
69    async def __call_batch_wrapper__(
70        self, prompt: P, idx: int, retry: int = 0
71    ) -> Response[R]:
72        if self.delay > 0 and retry == 0:
73            await asyncio.sleep(idx * self.delay)
74
75        try:
76            result = await self.__call_batch__(prompt)
77            logger.debug(f"Result of batch {idx + 1}: {result}")
78            return result
79
80        except Exception as e:
81            if retry < self.retries:
82                logger.info(f"Retrying batch {idx + 1}...")
83                return await self.__call_batch_wrapper__(prompt, idx, retry + 1)
84
85            if self.default_response is not None:
86                logger.error(f"Error processing batch {idx + 1}: {e}")
87                return Response(self.default_response, Usage(0, 0))
88
89            raise e
90
91    @abstractmethod
92    async def __call_batch__(self, prompt: P) -> Response[R]: ...
model: str
response_type: type[R]
delay: float
retries: int
default_response: Optional[R]
extra_kwargs: Mapping[str, typing.Any]
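BaseProvider drives the event loop, staggers batches by delay, retries failed batches, and falls back to default_response; subclasses only implement the async __call_batch__. A minimal sketch of a custom provider (the echo behavior is purely illustrative):

    from dataclasses import dataclass

    from cbrkit.synthesis.providers import BaseProvider, Response, Usage

    @dataclass(slots=True, kw_only=True)
    class echo(BaseProvider[str, str]):
        async def __call_batch__(self, prompt: str) -> Response[str]:
            # A real subclass would issue an async API request here.
            return Response(f"{self.model}: {prompt}", Usage(0, 0))

    provider = echo(model="debug", response_type=str, retries=2, default_response="n/a")
    print(provider(["hello"])[0].value)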
@dataclass(slots=True, kw_only=True)
class ChatProvider(cbrkit.synthesis.providers.BaseProvider[P, R], abc.ABC, typing.Generic[P, R]):
95@dataclass(slots=True, kw_only=True)
96class ChatProvider[P, R](BaseProvider[P, R], ABC):
97    system_message: str | None = None
98    messages: Sequence[ChatMessage] = field(default_factory=tuple)
system_message: str | None
messages: Sequence[ChatMessage]
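ChatProvider adds a system message and a fixed message history that is prepended to every prompt, which is a convenient place for few-shot examples. A sketch with an illustrative example pair (model name is a placeholder):

    from cbrkit.synthesis.providers import ChatMessage, openai

    provider = openai(
        model="gpt-4o-mini",
        response_type=str,
        system_message="Answer with a single word.",
        messages=[
            ChatMessage(role="user", content="Capital of France?"),
            ChatMessage(role="assistant", content="Paris"),
        ],
    )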
@dataclass(slots=True, frozen=True)
class ChatMessage(typing.Generic[P]):
16@dataclass(slots=True, frozen=True)
17class ChatMessage[P]:
18    role: Literal["user", "assistant"]
19    content: P
ChatMessage(role: Literal['user', 'assistant'], content: P)
role: Literal['user', 'assistant']
content: P
@dataclass(slots=True, frozen=True)
class ChatPrompt(cbrkit.typing.StructuredValue[P], typing.Generic[P]):
22@dataclass(slots=True, frozen=True)
23class ChatPrompt[P](StructuredValue[P]):
24    messages: Sequence[ChatMessage[P]]
ChatPrompt( value: P, messages: Sequence[ChatMessage[P]])
messages: Sequence[ChatMessage[P]]
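A ChatPrompt bundles the new prompt value with an explicit message history, which chat providers append after their own configured messages. A short sketch with an illustrative history:

    from cbrkit.synthesis.providers import ChatMessage, ChatPrompt

    history = [
        ChatMessage(role="user", content="My printer shows error E04."),
        ChatMessage(role="assistant", content="That usually indicates a paper jam."),
    ]
    prompt = ChatPrompt("How do I clear it?", messages=history)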
@dataclass(slots=True, frozen=True)
class DocumentsPrompt(cbrkit.typing.StructuredValue[P], typing.Generic[P]):
27@dataclass(slots=True, frozen=True)
28class DocumentsPrompt[P](StructuredValue[P]):
29    documents: Mapping[str, Mapping[str, str]]
DocumentsPrompt(value: P, documents: Mapping[str, Mapping[str, str]])
documents: Mapping[str, Mapping[str, str]]
@dataclass(slots=True, frozen=True)
class Response(cbrkit.typing.StructuredValue[T], typing.Generic[T]):
42@dataclass(slots=True, frozen=True)
43class Response[T](StructuredValue[T]):
44    usage: Usage = Field(default_factory=Usage)
Response( value: T, usage: Usage = <factory>)
usage: Usage
@dataclass(slots=True, frozen=True)
class Usage:
32@dataclass(slots=True, frozen=True)
33class Usage:
34    prompt_tokens: int = 0
35    completion_tokens: int = 0
36
37    @property
38    def total_tokens(self) -> int:
39        return self.prompt_tokens + self.completion_tokens
Usage(prompt_tokens: int = 0, completion_tokens: int = 0)
prompt_tokens: int
completion_tokens: int
total_tokens: int
37    @property
38    def total_tokens(self) -> int:
39        return self.prompt_tokens + self.completion_tokens
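Response wraps the generated value together with a Usage record, so token counts can be aggregated over a batch; a quick sketch with illustrative values:

    from cbrkit.synthesis.providers import Response, Usage

    responses = [Response("a", Usage(10, 2)), Response("b", Usage(7, 5))]
    total = sum(r.usage.total_tokens for r in responses)  # 24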
@dataclass(slots=True)
class anthropic(cbrkit.synthesis.providers.ChatProvider[AnthropicPrompt, R], typing.Generic[R]):
 57    @dataclass(slots=True)
 58    class anthropic[R: str | BaseModel](ChatProvider[AnthropicPrompt, R]):
 59        max_tokens: int
 60        model: ModelParam
 61        client: AsyncAnthropic = field(default_factory=AsyncAnthropic, repr=False)
 62        metadata: MetadataParam | NotGiven = NOT_GIVEN
 63        stop_sequences: list[str] | NotGiven = NOT_GIVEN
 64        stream: NotGiven | Literal[False] = NOT_GIVEN
 65        system: str | Iterable[TextBlockParam] | NotGiven = NOT_GIVEN
 66        temperature: float | NotGiven = NOT_GIVEN
 67        tool_choice: ToolChoiceParam | NotGiven = NOT_GIVEN
 68        tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN
 69        top_k: int | NotGiven = NOT_GIVEN
 70        top_p: float | NotGiven = NOT_GIVEN
 71        extra_headers: Any | None = None
 72        extra_query: Any | None = None
 73        extra_body: Any | None = None
 74        timeout: float | Timeout | NotGiven | None = NOT_GIVEN
 75
 76        @override
 77        async def __call_batch__(self, prompt: AnthropicPrompt) -> Response[R]:
 78            messages: list[MessageParam] = []
 79
 80            if self.system_message is not None:
 81                messages.append(
 82                    {
 83                        "role": "user",  # anthropic doesn't have a "system" role
 84                        "content": self.system_message,
 85                    }
 86                )
 87            messages.extend(cast(Sequence[MessageParam], self.messages))
 88
 89            if isinstance(prompt, ChatPrompt):
 90                messages.extend(cast(Sequence[MessageParam], prompt.messages))
 91
 92            if self.messages and self.messages[-1].role == "user":
 93                messages.append({"role": "assistant", "content": unpack_value(prompt)})
 94            else:
 95                messages.append({"role": "user", "content": unpack_value(prompt)})
 96
 97            tool = (
 98                pydantic_to_anthropic_schema(self.response_type)
 99                if issubclass(self.response_type, BaseModel)
100                else None
101            )
102
103            toolchoice = (
104                cast(ToolChoiceParam, {"type": "tool", "name": tool["name"]})
105                if tool is not None
106                else None
107            )
108            tool = cast(ToolParam, tool) if tool is not None else None
109            res = await self.client.messages.create(
110                model=self.model,
111                messages=messages,
112                max_tokens=self.max_tokens,
113                tools=[tool] if tool is not None else NOT_GIVEN,
114                tool_choice=toolchoice if toolchoice is not None else NOT_GIVEN,
115            )
116            if issubclass(self.response_type, BaseModel):
117                # res.content should contain one ToolUseBlock
118                if len(res.content) != 1:
119                    raise ValueError("Expected one ToolUseBlock, got", len(res.content))
120                block = res.content[0]
121                if block.type != "tool_use":
122                    raise ValueError("Expected one ToolUseBlock, got", block.type)
123                return Response(self.response_type.model_validate(block.input))
124
125            elif self.response_type is str:
126                str_res = ""
127                for block in res.content:
128                    if block.type == "text":
129                        str_res += block.text
130                return Response(cast(R, str_res))
131
132            raise ValueError("Invalid response", res)
anthropic( model: Union[Literal['claude-3-7-sonnet-latest', 'claude-3-7-sonnet-20250219', 'claude-3-5-haiku-latest', 'claude-3-5-haiku-20241022', 'claude-sonnet-4-20250514', 'claude-sonnet-4-0', 'claude-4-sonnet-20250514', 'claude-3-5-sonnet-latest', 'claude-3-5-sonnet-20241022', 'claude-3-5-sonnet-20240620', 'claude-opus-4-0', 'claude-opus-4-20250514', 'claude-4-opus-20250514', 'claude-3-opus-latest', 'claude-3-opus-20240229', 'claude-3-sonnet-20240229', 'claude-3-haiku-20240307', 'claude-2.1', 'claude-2.0'], str], max_tokens: int, client: anthropic.AsyncAnthropic = <factory>, metadata: anthropic.types.metadata_param.MetadataParam | anthropic.NotGiven = NOT_GIVEN, stop_sequences: list[str] | anthropic.NotGiven = NOT_GIVEN, stream: Union[anthropic.NotGiven, Literal[False]] = NOT_GIVEN, system: str | Iterable[anthropic.types.text_block_param.TextBlockParam] | anthropic.NotGiven = NOT_GIVEN, temperature: float | anthropic.NotGiven = NOT_GIVEN, tool_choice: Union[anthropic.types.tool_choice_auto_param.ToolChoiceAutoParam, anthropic.types.tool_choice_any_param.ToolChoiceAnyParam, anthropic.types.tool_choice_tool_param.ToolChoiceToolParam, anthropic.types.tool_choice_none_param.ToolChoiceNoneParam, anthropic.NotGiven] = NOT_GIVEN, tools: Iterable[anthropic.types.tool_param.ToolParam] | anthropic.NotGiven = NOT_GIVEN, top_k: int | anthropic.NotGiven = NOT_GIVEN, top_p: float | anthropic.NotGiven = NOT_GIVEN, extra_headers: typing.Any | None = None, extra_query: typing.Any | None = None, extra_body: typing.Any | None = None, timeout: float | anthropic.Timeout | anthropic.NotGiven | None = NOT_GIVEN, *, response_type: type[R], delay: float = 0, retries: int = 0, default_response: Optional[R] = None, extra_kwargs: Mapping[str, typing.Any] = <factory>, system_message: str | None = None, messages: Sequence[ChatMessage] = <factory>)
max_tokens: int
model: Union[Literal['claude-3-7-sonnet-latest', 'claude-3-7-sonnet-20250219', 'claude-3-5-haiku-latest', 'claude-3-5-haiku-20241022', 'claude-sonnet-4-20250514', 'claude-sonnet-4-0', 'claude-4-sonnet-20250514', 'claude-3-5-sonnet-latest', 'claude-3-5-sonnet-20241022', 'claude-3-5-sonnet-20240620', 'claude-opus-4-0', 'claude-opus-4-20250514', 'claude-4-opus-20250514', 'claude-3-opus-latest', 'claude-3-opus-20240229', 'claude-3-sonnet-20240229', 'claude-3-haiku-20240307', 'claude-2.1', 'claude-2.0'], str]
client: anthropic.AsyncAnthropic
metadata: anthropic.types.metadata_param.MetadataParam | anthropic.NotGiven
stop_sequences: list[str] | anthropic.NotGiven
stream: Union[anthropic.NotGiven, Literal[False]]
system: str | Iterable[anthropic.types.text_block_param.TextBlockParam] | anthropic.NotGiven
temperature: float | anthropic.NotGiven
tool_choice: Union[anthropic.types.tool_choice_auto_param.ToolChoiceAutoParam, anthropic.types.tool_choice_any_param.ToolChoiceAnyParam, anthropic.types.tool_choice_tool_param.ToolChoiceToolParam, anthropic.types.tool_choice_none_param.ToolChoiceNoneParam, anthropic.NotGiven]
tools: Iterable[anthropic.types.tool_param.ToolParam] | anthropic.NotGiven
top_k: int | anthropic.NotGiven
top_p: float | anthropic.NotGiven
extra_headers: typing.Any | None
extra_query: typing.Any | None
extra_body: typing.Any | None
timeout: float | anthropic.Timeout | anthropic.NotGiven | None
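A sketch for the anthropic provider: max_tokens is required, and a BaseModel response type is obtained through a forced tool call rather than a response format. ANTHROPIC_API_KEY is assumed to be set in the environment, and Verdict is a hypothetical model:

    from pydantic import BaseModel

    from cbrkit.synthesis.providers import anthropic

    class Verdict(BaseModel):
        label: str
        reasoning: str

    provider = anthropic(
        model="claude-3-5-haiku-latest",
        max_tokens=512,
        response_type=Verdict,
    )

    verdict = provider(["Is case A or case B closer to the query?"])[0].value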
@dataclass(slots=True)
class instructor(cbrkit.synthesis.providers.ChatProvider[InstructorPrompt, R], typing.Generic[R]):
19    @dataclass(slots=True)
20    class instructor[R: BaseModel](ChatProvider[InstructorPrompt, R]):
21        client: AsyncInstructor = field(repr=False)
22        strict: bool = True
23        context: dict[str, Any] | None = None
24
25        @override
26        async def __call_batch__(self, prompt: InstructorPrompt) -> Response[R]:
27            messages: list[ChatCompletionMessageParam] = []
28
29            if self.system_message is not None:
30                messages.append(
31                    {
32                        "role": "system",
33                        "content": self.system_message,
34                    }
35                )
36
37            messages.extend(cast(Sequence[ChatCompletionMessageParam], self.messages))
38
39            if isinstance(prompt, ChatPrompt):
40                messages.extend(
41                    cast(Sequence[ChatCompletionMessageParam], prompt.messages)
42                )
43
44            if messages and messages[-1]["role"] == "user":
45                messages.append(
46                    {
47                        "role": "assistant",
48                        "content": unpack_value(prompt),
49                    }
50                )
51            else:
52                messages.append(
53                    {
54                        "role": "user",
55                        "content": unpack_value(prompt),
56                    }
57                )
58
59            # retries are already handled by the base provider
60            return Response(
61                await self.client.chat.completions.create(
62                    model=self.model,
63                    messages=messages,
64                    response_model=self.response_type,
65                    context=self.context,
66                    **self.extra_kwargs,
67                )
68            )
instructor( client: instructor.client.AsyncInstructor, strict: bool = True, context: dict[str, typing.Any] | None = None, *, model: str, response_type: type[R], delay: float = 0, retries: int = 0, default_response: Optional[R] = None, extra_kwargs: Mapping[str, typing.Any] = <factory>, system_message: str | None = None, messages: Sequence[ChatMessage] = <factory>)
client: instructor.client.AsyncInstructor
strict: bool
context: dict[str, typing.Any] | None
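The instructor provider expects an explicit AsyncInstructor client and a BaseModel response type; a sketch that wraps the async OpenAI SDK client, assuming the instructor library's from_openai helper, a placeholder model name, and a hypothetical Fact model:

    import instructor as instructor_lib
    from openai import AsyncOpenAI
    from pydantic import BaseModel

    from cbrkit.synthesis.providers import instructor

    class Fact(BaseModel):
        statement: str

    provider = instructor(
        client=instructor_lib.from_openai(AsyncOpenAI()),
        model="gpt-4o-mini",
        response_type=Fact,
    )

    fact = provider(["State one fact about case-based reasoning."])[0].value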