cbrkit.synthesis.providers
LLM provider integrations for synthesis.
Each provider wraps an LLM API and exposes a unified interface for use with
cbrkit.synthesis.build.
Providers are initialized with a model name and a response type (str for
plain text or a Pydantic model for structured output).
Additional options like temperature, seed, and max_tokens can be set.
Providers (each requires its respective extra and API key):
- openai / openai_completions: OpenAI Completions API (OPENAI_API_KEY).
- openai_responses: OpenAI Responses API (OPENAI_API_KEY).
- openai_agents: OpenAI Agents framework (OPENAI_API_KEY).
- anthropic: Anthropic Claude API (ANTHROPIC_API_KEY).
- cohere: Cohere API (CO_API_KEY).
- google: Google Generative AI (GOOGLE_API_KEY).
- ollama: Ollama local inference (no API key needed).
- pydantic_ai: Pydantic AI framework.
- instructor: Instructor for structured output.
Wrappers:
- pipe: Chains multiple providers sequentially.
- conversation: Manages multi-turn conversations with a provider.
Base Classes:
- BaseProvider: Base class for providers with model configuration, retry logic, and error handling.
- AsyncProvider: Base class for async batch-processing providers.
- Response: Response model returned by providers.
- Usage: Token usage tracking.
Example:
>>> provider = openai(  # doctest: +SKIP
...     model="gpt-4o",
...     response_type=str,
...     temperature=0.7,
... )
1"""LLM provider integrations for synthesis. 2 3Each provider wraps an LLM API and exposes a unified interface for use with 4`cbrkit.synthesis.build`. 5Providers are initialized with a model name and a response type (`str` for 6plain text or a Pydantic model for structured output). 7Additional options like `temperature`, `seed`, and `max_tokens` can be set. 8 9Providers (each requires its respective extra and API key): 10- `openai` / `openai_completions`: OpenAI Completions API (`OPENAI_API_KEY`). 11- `openai_responses`: OpenAI Responses API (`OPENAI_API_KEY`). 12- `openai_agents`: OpenAI Agents framework (`OPENAI_API_KEY`). 13- `anthropic`: Anthropic Claude API (`ANTHROPIC_API_KEY`). 14- `cohere`: Cohere API (`CO_API_KEY`). 15- `google`: Google Generative AI (`GOOGLE_API_KEY`). 16- `ollama`: Ollama local inference (no API key needed). 17- `pydantic_ai`: Pydantic AI framework. 18- `instructor`: Instructor for structured output. 19 20Wrappers: 21- `pipe`: Chains multiple providers sequentially. 22- `conversation`: Manages multi-turn conversations with a provider. 23 24Base Classes: 25- `BaseProvider`: Base class for synchronous providers. 26- `AsyncProvider`: Base class for asynchronous providers. 27- `Response`: Response model returned by providers. 28- `Usage`: Token usage tracking. 29 30Example: 31 >>> provider = openai( # doctest: +SKIP 32 ... model="gpt-4o", 33 ... response_type=str, 34 ... temperature=0.7, 35 ... ) 36""" 37 38from ...helpers import optional_dependencies 39from .model import AsyncProvider, BaseProvider, Response, Usage 40from .wrappers import conversation, pipe 41 42with optional_dependencies(): 43 from .openai_completions import openai_completions 44 45 openai = openai_completions 46with optional_dependencies(): 47 from .openai_responses import openai_responses 48with optional_dependencies(): 49 from .anthropic import anthropic 50with optional_dependencies(): 51 from .cohere import cohere 52with optional_dependencies(): 53 from .ollama import ollama 54with optional_dependencies(): 55 from .instructor import instructor 56with optional_dependencies(): 57 from .google import google 58with optional_dependencies(): 59 from .pydantic_ai import pydantic_ai 60with optional_dependencies(): 61 from .openai_agents import openai_agents 62 63__all__ = [ 64 "AsyncProvider", 65 "BaseProvider", 66 "Response", 67 "Usage", 68 "pipe", 69 "conversation", 70 "anthropic", 71 "cohere", 72 "google", 73 "instructor", 74 "ollama", 75 "openai", 76 "openai_agents", 77 "openai_completions", 78 "openai_responses", 79 "pydantic_ai", 80]
@dataclass(slots=True, kw_only=True)
class AsyncProvider[P, R](BatchConversionFunc[P, R], ABC):
    """Base class for async batch-processing providers."""

    def __call__(self, batches: Sequence[P]) -> Sequence[R]:
        return run_coroutine(self.__call_batches__(batches))

    async def __call_batches__(self, batches: Sequence[P]) -> Sequence[R]:
        logger.info(f"Processing {len(batches)} batches")

        return await asyncio.gather(
            *(
                self.__call_batch_wrapper__(batch, idx)
                for idx, batch in enumerate(batches)
            )
        )

    async def __call_batch_wrapper__(self, prompt: P, idx: int) -> R:
        result = await self.__call_batch__(prompt)
        logger.debug(f"Result of batch {idx + 1}: {result}")
        return result

    @abstractmethod
    async def __call_batch__(self, prompt: P) -> R: ...
Base class for async batch-processing providers.
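To illustrate the contract: a concrete provider only implements __call_batch__; batching, concurrent gathering, and logging come from the base class. A minimal hypothetical subclass (for illustration only, no real LLM call):

>>> from dataclasses import dataclass  # doctest: +SKIP
>>> @dataclass(slots=True, kw_only=True)
... class echo_provider(AsyncProvider[str, str]):
...     async def __call_batch__(self, prompt: str) -> str:
...         # just echo each prompt back
...         return prompt
>>> echo_provider()(["a", "b"])
['a', 'b']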
@dataclass(slots=True, kw_only=True)
class BaseProvider[P, R](AsyncProvider[P, Response[R]], ABC):
    """Base provider with model configuration, retry logic, and error handling."""

    model: str
    response_type: type[R]
    default_response: R | None = None
    system_message: str | None = None
    delay: float = 0
    retries: int = 0
    extra_kwargs: Mapping[str, Any] = field(default_factory=dict)

    @override
    async def __call_batch_wrapper__(
        self, prompt: P, idx: int, retry: int = 0
    ) -> Response[R]:
        if self.delay > 0 and retry == 0:
            await asyncio.sleep(idx * self.delay)

        try:
            return await super(BaseProvider, self).__call_batch_wrapper__(prompt, idx)

        except Exception as e:
            if retry < self.retries:
                logger.info(f"Retrying batch {idx + 1}...")
                return await self.__call_batch_wrapper__(prompt, idx, retry + 1)

            if self.default_response is not None:
                logger.error(f"Error processing batch {idx + 1}: {e}")
                return Response(self.default_response, Usage(0, 0))

            raise e
Base provider with model configuration, retry logic, and error handling.
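The delay, retries, and default_response fields apply to every concrete provider. A sketch using openai_completions (model name illustrative): stagger request starts by 0.5 s per batch, retry each failed batch twice, and fall back to an empty string instead of raising:

>>> provider = openai_completions(  # doctest: +SKIP
...     model="gpt-4o-mini",
...     response_type=str,
...     delay=0.5,
...     retries=2,
...     default_response="",
... )

Note that the fallback only triggers when default_response is not None; otherwise the last exception propagates.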
@dataclass(slots=True, frozen=True)
class Response[T](StructuredValue[T]):
    """Provider response wrapping a value with usage statistics."""

    usage: Usage = field(default_factory=Usage)
Provider response wrapping a value with usage statistics.
@dataclass(slots=True, frozen=True)
class Usage:
    """Token usage statistics for a provider response."""

    prompt_tokens: int = 0
    completion_tokens: int = 0

    @property
    def total_tokens(self) -> int:
        """Return the sum of prompt and completion tokens."""
        return self.prompt_tokens + self.completion_tokens
Token usage statistics for a provider response.
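For example, total_tokens is simply the sum of both counters:

>>> usage = Usage(prompt_tokens=120, completion_tokens=30)
>>> usage.total_tokens
150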
@dataclass(slots=True, frozen=True)
class pipe[P, R](BatchConversionFunc[P, R]):
    """Chains multiple generation functions, converting output back to input between steps."""

    generation_funcs: MaybeSequence[AnyConversionFunc[P, R]]
    conversion_func: ConversionFunc[R, P]

    def __call__(self, batches: Sequence[P]) -> Sequence[R]:
        funcs = produce_sequence(self.generation_funcs)
        current_input = batches
        current_output: Sequence[R] = []

        for func in funcs:
            batch_func = batchify_conversion(func)
            current_output = batch_func(current_input)
            current_input = [self.conversion_func(output) for output in current_output]

        if len(current_output) != len(batches):
            raise ValueError(
                "The number of outputs does not match the number of inputs, "
                "did you provide a generation function?"
            )

        return current_output
Chains multiple generation functions, converting output back to input between steps.
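A sketch of a two-step draft-then-refine chain. It assumes two already-configured providers that take str prompts and return Response[str], and that Response (a StructuredValue) exposes the wrapped text as .value; draft_provider and critique_provider are hypothetical:

>>> refine = pipe(  # doctest: +SKIP
...     generation_funcs=[draft_provider, critique_provider],
...     conversion_func=lambda response: response.value,
... )
>>> refine(["Summarize the retrieved cases."])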
@dataclass(slots=True, frozen=True)
class conversation[P, R](ConversionFunc[Sequence[P], R]):
    """Iteratively generates responses until the conversion function returns None."""

    generation_func: AnyConversionFunc[Sequence[P], R]
    conversion_func: ConversionFunc[R, Sequence[P] | None]

    def __call__(self, batch: Sequence[P]) -> R:
        func = unbatchify_conversion(self.generation_func)
        result = func(batch)

        while next_batch := self.conversion_func(result):
            result = func(next_batch)

        return result
Iteratively generates responses until the conversion function returns None.
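A sketch of a self-terminating loop (all names hypothetical). Since each iteration calls the generation function with only what conversion_func returns, the conversion function must return the complete next message batch, or None (or any falsy value) to stop:

>>> def next_turn(result):  # doctest: +SKIP
...     if "DONE" in result.value:
...         return None  # falsy value ends the loop
...     return [{"role": "assistant", "content": result.value},
...             {"role": "user", "content": "Continue."}]
>>> chat = conversation(
...     generation_func=provider,
...     conversion_func=next_turn,
... )
>>> final = chat([{"role": "user", "content": "Start."}])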
@dataclass(slots=True)
class anthropic[R: str | BaseModel](BaseProvider[AnthropicPrompt, R]):
    """Provider that calls Anthropic's messages API."""

    model: ModelParam
    max_tokens: int
    messages: Sequence[BetaMessageParam] = field(default_factory=tuple)
    client: AsyncAnthropic = field(default_factory=AsyncAnthropic, repr=False)
    metadata: MetadataParam | None = None
    stop_sequences: list[str] | None = None
    system: str | Iterable[TextBlockParam] | None = None
    temperature: float | None = None
    tool_choice: ToolChoiceParam | None = None
    tools: Iterable[ToolParam] | None = None
    top_k: int | None = None
    top_p: float | None = None
    extra_headers: Any | None = None
    extra_query: Any | None = None
    extra_body: Any | None = None
    timeout: float | Timeout | None = None

    @override
    async def __call_batch__(self, prompt: AnthropicPrompt) -> Response[R]:
        messages: list[BetaMessageParam] = []

        if self.system_message is not None:
            # anthropic does not have a system/developer role
            messages.append({"role": "user", "content": self.system_message})

        messages.extend(self.messages)

        if isinstance(prompt, str):
            messages.append({"role": "user", "content": prompt})
        else:
            messages.extend(prompt)

        res = await self.client.beta.messages.parse(
            model=self.model,
            messages=messages,
            max_tokens=self.max_tokens,
            output_format=self.response_type  # type: ignore[arg-type]
            if issubclass(self.response_type, BaseModel)
            else omit,
        )

        usage = Usage(
            res.usage.input_tokens,
            res.usage.output_tokens,
        )

        if (
            isinstance(self.response_type, type)
            and issubclass(self.response_type, BaseModel)
            and (parsed := res.parsed_output) is not None
        ):
            return Response(parsed, usage)

        if (
            isinstance(self.response_type, type)
            and issubclass(self.response_type, str)
            and len(res.content) > 0
        ):
            aggregated_content = "".join(
                getattr(block, "text", "") for block in res.content
            )
            return Response(cast(R, aggregated_content), usage)

        raise ValueError("Invalid response", res)
Provider that calls Anthropic's messages API.
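A minimal instantiation sketch (model name illustrative; requires ANTHROPIC_API_KEY). max_tokens has no default because the Anthropic API requires it:

>>> provider = anthropic(  # doctest: +SKIP
...     model="claude-3-5-sonnet-latest",
...     max_tokens=1024,
...     response_type=str,
... )
>>> provider(["Explain case-based reasoning in one sentence."])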
@dataclass(slots=True)
class cohere[R: str | BaseModel](BaseProvider[CoherePrompt, R]):
    """Provider that calls Cohere's chat API."""

    messages: Sequence[ChatMessageV2] = field(default_factory=tuple)
    documents: Sequence[V2ChatRequestDocumentsItem] = field(default_factory=tuple)
    client: AsyncClient = field(default_factory=AsyncClient, repr=False)
    request_options: RequestOptions | None = None
    citation_options: CitationOptions | None = None
    safety_mode: V2ChatRequestSafetyMode | None = None
    max_tokens: int | None = None
    stop_sequences: Sequence[str] | None = None
    temperature: float | None = None
    seed: int | None = None
    frequency_penalty: float | None = None
    presence_penalty: float | None = None
    k: int | None = None
    p: float | None = None
    logprobs: bool | None = None

    @override
    async def __call_batch__(self, prompt: CoherePrompt) -> Response[R]:
        documents: list[V2ChatRequestDocumentsItem] = list(self.documents)

        if isinstance(prompt, CohereDocumentsPrompt):
            documents.extend(prompt.documents)

        if issubclass(self.response_type, BaseModel) and documents:
            raise ValueError(
                "Structured output format is not supported when using documents"
            )

        messages: list[ChatMessageV2] = []

        if self.system_message is not None:
            messages.append(SystemChatMessageV2(content=self.system_message))

        if isinstance(prompt, str):
            messages.append(UserChatMessageV2(content=prompt))
        elif isinstance(prompt, CohereDocumentsPrompt):
            messages.extend(prompt.messages)
        else:
            messages.extend(prompt)

        res = await self.client.v2.chat(
            model=self.model,
            messages=messages,
            request_options=self.request_options,
            documents=documents if documents else None,
            response_format=JsonObjectResponseFormatV2(
                json_schema=self.response_type.model_json_schema()
            )
            if issubclass(self.response_type, BaseModel)
            else None,
            citation_options=self.citation_options,
            safety_mode=self.safety_mode,
            max_tokens=self.max_tokens,
            stop_sequences=self.stop_sequences,
            temperature=self.temperature,
            seed=self.seed,
            frequency_penalty=self.frequency_penalty,
            presence_penalty=self.presence_penalty,
            k=self.k,
            p=self.p,
            logprobs=self.logprobs,
            **self.extra_kwargs,
        )

        content = res.message.content

        if content is None:
            raise ValueError("The completion is empty")

        if issubclass(self.response_type, BaseModel):
            if len(content) != 1 or content[0].type != "text":
                raise ValueError(
                    "The completion is empty, has multiple outputs, or is not text"
                )

            return Response(self.response_type.model_validate_json(content[0].text))

        aggregated_content = "".join(
            block.text for block in content if block.type == "text"
        )

        return Response(cast(R, aggregated_content))
Provider that calls Cohere's chat API.
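A minimal sketch (model name illustrative; requires CO_API_KEY). Note from the code above that a Pydantic response_type cannot be combined with documents:

>>> provider = cohere(  # doctest: +SKIP
...     model="command-r-plus",
...     response_type=str,
...     temperature=0.3,
... )
>>> provider(["Explain case-based reasoning in one sentence."])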
@dataclass(slots=True)
class google[R: BaseModel | str](BaseProvider[GooglePrompt, R]):
    """Provider that calls Google's Generative AI API."""

    client: Client = field(default_factory=Client, repr=False)
    config: GenerateContentConfig = field(init=False)
    base_config: InitVar[GenerateContentConfig | None] = None

    def __post_init__(self, base_config: GenerateContentConfig | None) -> None:
        self.config = base_config or GenerateContentConfig()

        if issubclass(self.response_type, BaseModel):
            self.config.response_schema = self.response_type

        if self.system_message is not None:
            self.config.system_instruction = self.system_message

    @override
    async def __call_batch__(self, prompt: GooglePrompt) -> Response[R]:
        res = await self.client.aio.models.generate_content(
            model=self.model,
            contents=prompt,
            config=self.config,
            **self.extra_kwargs,
        )

        if (
            issubclass(self.response_type, BaseModel)
            and (parsed := res.parsed)
            and isinstance(parsed, self.response_type)
        ):
            return Response(cast(R, parsed))

        elif issubclass(self.response_type, str) and (text := res.text):
            return Response(cast(R, text))

        raise ValueError("Invalid response", res)
Provider that calls Google's Generative AI API.
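A structured-output sketch (model name illustrative; requires GOOGLE_API_KEY). Generation settings beyond the base fields go through base_config, which __post_init__ extends with the response schema and system instruction:

>>> from pydantic import BaseModel  # doctest: +SKIP
>>> class Verdict(BaseModel):
...     label: str
...     confidence: float
>>> provider = google(
...     model="gemini-2.0-flash",
...     response_type=Verdict,
... )
>>> provider(["Does case A match the query?"])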
@dataclass(slots=True)
class instructor[R: BaseModel](BaseProvider[InstructorPrompt, R]):
    """Provider that uses the instructor library for structured outputs."""

    client: AsyncInstructor = field(repr=False)
    messages: Sequence[ChatCompletionMessageParam] = field(default_factory=tuple)
    strict: bool = True
    context: dict[str, Any] | None = None

    @override
    async def __call_batch__(self, prompt: InstructorPrompt) -> Response[R]:
        messages: list[ChatCompletionMessageParam] = []

        if self.system_message is not None:
            messages.append({"role": "system", "content": self.system_message})

        messages.extend(self.messages)

        if isinstance(prompt, str):
            messages.append({"role": "user", "content": prompt})
        else:
            messages.extend(prompt)

        # retries are already handled by the base provider
        return Response(
            await self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                response_model=self.response_type,
                context=self.context,
                **self.extra_kwargs,
            )
        )
Provider that uses the instructor library for structured outputs.
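The client field is required and must be an AsyncInstructor. A sketch wrapping an async OpenAI client via instructor.from_openai (model name illustrative; the import is aliased because the provider class shares the library's name):

>>> import instructor as instructor_lib  # doctest: +SKIP
>>> from openai import AsyncOpenAI
>>> from pydantic import BaseModel
>>> class Summary(BaseModel):
...     text: str
>>> provider = instructor(
...     client=instructor_lib.from_openai(AsyncOpenAI()),
...     model="gpt-4o-mini",
...     response_type=Summary,
... )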
@dataclass(slots=True)
class ollama[R: str | BaseModel](BaseProvider[OllamaPrompt, R]):
    """Provider that calls Ollama's chat API."""

    client: AsyncClient = field(default_factory=AsyncClient, repr=False)
    messages: Sequence[Message] = field(default_factory=tuple)
    options: Options | None = None
    keep_alive: float | str | None = None

    @override
    async def __call_batch__(self, prompt: OllamaPrompt) -> Response[R]:
        messages: list[Message] = []

        if self.system_message is not None:
            messages.append(Message(role="system", content=self.system_message))

        messages.extend(self.messages)

        if isinstance(prompt, str):
            messages.append(Message(role="user", content=prompt))
        else:
            messages.extend(prompt)

        res = await self.client.chat(
            model=self.model,
            messages=messages,
            options=self.options,
            keep_alive=self.keep_alive,
            format=self.response_type.model_json_schema()
            if issubclass(self.response_type, BaseModel)
            else None,
            **self.extra_kwargs,
        )

        content = res["message"]["content"]

        if self.response_type is str:
            return Response(content)

        return Response(json.loads(content))
Provider that calls Ollama's chat API.
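A minimal sketch, assuming a local Ollama server with the model already pulled (model name illustrative). Passing a Pydantic model as response_type sends its JSON schema via format and parses the reply with json.loads:

>>> provider = ollama(  # doctest: +SKIP
...     model="llama3.2",
...     response_type=str,
... )
>>> provider(["Why is the sky blue?"])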
@dataclass(slots=True)
class openai_agents[T, R](AsyncProvider[OpenaiAgentsPrompt, TypedRunResult[R]]):
    """Provider that runs OpenAI Agents SDK agents."""

    agents: MaybeSequence[Agent[T]]
    context: T | None = None
    max_turns: int = DEFAULT_MAX_TURNS
    hooks: RunHooks[T] | None = None
    run_config: RunConfig | None = None

    @override
    async def __call_batch__(self, prompt: OpenaiAgentsPrompt) -> TypedRunResult[R]:
        agents = produce_sequence(self.agents)

        if not agents:
            raise ValueError("No agents given.")

        head_agent, *tail_agents = agents

        session: Any = SQLiteSession(uuid1().hex) if len(agents) > 1 else None

        run = partial(
            Runner.run,
            context=self.context,
            max_turns=self.max_turns,
            hooks=self.hooks,
            run_config=self.run_config,
            session=session,
        )

        response: RunResult = await run(head_agent, prompt)

        for agent in tail_agents:
            response = await run(agent, [])

        return cast(TypedRunResult[R], response)
Provider that runs OpenAI Agents SDK agents.
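A sketch with a single agent from the OpenAI Agents SDK (name and instructions illustrative; requires OPENAI_API_KEY). With multiple agents, the code above chains them through a shared SQLiteSession, feeding each subsequent agent the accumulated conversation:

>>> from agents import Agent  # doctest: +SKIP
>>> summarizer = Agent(name="summarizer", instructions="Summarize the input.")
>>> provider = openai_agents(agents=summarizer)
>>> results = provider(["Case A: ..."])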
@dataclass(slots=True)
class openai_completions[R: BaseModel | str](BaseProvider[OpenAiPrompt, R]):
    """Provider that calls OpenAI's chat completions API."""

    model: str | ChatModel
    messages: Sequence[ChatCompletionMessageParam] = field(default_factory=tuple)
    tool_choice: type[BaseModel] | str | None = None
    client: AsyncOpenAI = field(default_factory=AsyncOpenAI, repr=False)
    frequency_penalty: float | None = None
    logit_bias: dict[str, int] | None = None
    logprobs: bool | None = None
    max_completion_tokens: int | None = None
    metadata: dict[str, str] | None = None
    n: int | None = None
    presence_penalty: float | None = None
    seed: int | None = None
    stop: str | list[str] | None = None
    store: bool | None = None
    reasoning_effort: Literal["low", "medium", "high"] | None = None
    temperature: float | None = None
    top_logprobs: int | None = None
    top_p: float | None = None
    extra_headers: Any | None = None
    extra_query: Any | None = None
    extra_body: Any | None = None
    timeout: float | Timeout | None = None

    @override
    async def __call_batch__(self, prompt: OpenAiPrompt) -> Response[R]:
        messages: list[ChatCompletionMessageParam] = []

        if self.system_message is not None:
            messages.append({"role": "system", "content": self.system_message})

        messages.extend(self.messages)

        if isinstance(prompt, str):
            messages.append({"role": "user", "content": prompt})
        else:
            messages.extend(prompt)

        tools: list[ChatCompletionToolParam] | None = None
        tool_choice: ChatCompletionNamedToolChoiceParam | None = None
        response_type_origin = get_origin(self.response_type)

        if response_type_origin is UnionType or response_type_origin is Union:
            tools = [
                pydantic_function_tool(tool)
                for tool in get_args(self.response_type)
                if issubclass(tool, BaseModel)
            ]
        elif (
            issubclass(self.response_type, BaseModel)
            and self.tool_choice is not None
        ):
            tools = [pydantic_function_tool(self.response_type)]

        if self.tool_choice is not None:
            tool_choice = {
                "type": "function",
                "function": {
                    "name": self.tool_choice
                    if isinstance(self.tool_choice, str)
                    else self.tool_choice.__name__,
                },
            }

        try:
            res = await self.client.beta.chat.completions.parse(
                model=self.model,
                messages=messages,
                response_format=self.response_type  # type: ignore[arg-type]
                if tools is None and issubclass(self.response_type, BaseModel)
                else omit,
                tools=if_given(tools),
                tool_choice=if_given(tool_choice),
                frequency_penalty=if_given(self.frequency_penalty),
                logit_bias=if_given(self.logit_bias),
                logprobs=if_given(self.logprobs),
                max_completion_tokens=if_given(self.max_completion_tokens),
                metadata=if_given(self.metadata),
                n=if_given(self.n),
                presence_penalty=if_given(self.presence_penalty),
                seed=if_given(self.seed),
                stop=if_given(self.stop),
                store=if_given(self.store),
                reasoning_effort=if_given(self.reasoning_effort),
                temperature=if_given(self.temperature),
                top_logprobs=if_given(self.top_logprobs),
                top_p=if_given(self.top_p),
                extra_headers=self.extra_headers,
                extra_query=self.extra_query,
                extra_body=self.extra_body,
                timeout=self.timeout,
                **self.extra_kwargs,
            )
        except ValidationError as e:
            for error in e.errors():
                logger.error(f"Invalid response ({error['msg']}): {error['input']}")
            raise

        choice = res.choices[0]
        message = choice.message

        assert res.usage is not None
        usage = Usage(res.usage.prompt_tokens, res.usage.completion_tokens)

        if choice.finish_reason == "length":
            raise ValueError("Length limit", res)

        if choice.finish_reason == "content_filter":
            raise ValueError("Content filter", res)

        if message.refusal:
            raise ValueError("Refusal", res)

        if (
            isinstance(self.response_type, type)
            and issubclass(self.response_type, BaseModel)
            and (parsed := message.parsed) is not None
        ):
            return Response(cast(R, parsed), usage)

        if (
            isinstance(self.response_type, type)
            and issubclass(self.response_type, str)
            and (content := message.content) is not None
        ):
            return Response(cast(R, content), usage)

        if (
            tools is not None
            and (tool_calls := message.tool_calls) is not None
            and (parsed := tool_calls[0].function.parsed_arguments) is not None
        ):
            return Response(cast(R, parsed), usage)

        raise ValueError("Invalid response", res)
Provider that calls OpenAI's chat completions API.
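A structured-output sketch (model name illustrative; requires OPENAI_API_KEY). With a single Pydantic model as response_type and no tool_choice, the request takes the response_format path above; a union of models is converted into function tools instead:

>>> from pydantic import BaseModel  # doctest: +SKIP
>>> class Assessment(BaseModel):
...     score: float
...     rationale: str
>>> provider = openai_completions(
...     model="gpt-4o-mini",
...     response_type=Assessment,
...     temperature=0.0,
...     seed=42,
... )
>>> responses = provider(["Rate the similarity of case A and case B."])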
@dataclass(slots=True)
class openai_responses[R: BaseModel | str](BaseProvider[OpenAiResponsesPrompt, R]):
    """Provider that calls the OpenAI Responses API and parses structured outputs."""

    input_items: Sequence[ResponseInputItemParam] = field(default_factory=tuple)
    tool_choice: type[BaseModel] | str | None = None
    client: AsyncOpenAI = field(default_factory=AsyncOpenAI, repr=False)
    include: Sequence[ResponseIncludable] | None = None
    max_output_tokens: int | None = None
    max_tool_calls: int | None = None
    metadata: dict[str, str] | None = None
    parallel_tool_calls: bool | None = None
    store: bool | None = None
    temperature: float | None = None
    top_logprobs: int | None = None
    top_p: float | None = None
    text: ResponseTextConfigParam | None = None
    extra_headers: Any | None = None
    extra_query: Any | None = None
    extra_body: Any | None = None
    timeout: float | Timeout | None = None

    @override
    async def __call_batch__(self, prompt: OpenAiResponsesPrompt) -> Response[R]:
        inputs: list[ResponseInputItemParam] = []

        if self.system_message is not None:
            inputs.append({"role": "system", "content": self.system_message})

        inputs.extend(self.input_items)

        if isinstance(prompt, str):
            inputs.append({"role": "user", "content": prompt})
        else:
            inputs.extend(prompt)

        tools: list[ParseableToolParam] | None = None
        tool_choice: ToolChoice | None = None
        text_format: type[BaseModel] | Omit = omit

        response_type_origin = get_origin(self.response_type)

        if response_type_origin is UnionType or response_type_origin is Union:
            tools = [
                cast(ParseableToolParam, pydantic_function_tool(tool))
                for tool in get_args(self.response_type)
                if issubclass(tool, BaseModel)
            ]
        elif issubclass(self.response_type, BaseModel):
            if self.tool_choice is not None:
                tools = [
                    cast(
                        ParseableToolParam,
                        pydantic_function_tool(self.response_type),
                    )
                ]
            else:
                text_format = self.response_type

        if self.tool_choice is not None:
            tool_choice = ToolChoiceFunctionParam(
                name=self.tool_choice
                if isinstance(self.tool_choice, str)
                else self.response_type.__name__,
                type="function",
            )

        text_param: ResponseTextConfigParam | Omit

        if self.text is None:
            text_param = omit
        elif text_format is not omit and "format" in self.text:
            raise ValueError(
                "`text.format` cannot be set when using structured outputs."
            )
        else:
            text_param = self.text

        try:
            res = await self.client.responses.parse(
                model=self.model,
                input=inputs,
                instructions=if_given(self.system_message),
                include=if_given(
                    list(self.include) if self.include is not None else None
                ),
                tools=if_given(tools),
                tool_choice=if_given(tool_choice),
                max_output_tokens=if_given(self.max_output_tokens),
                max_tool_calls=if_given(self.max_tool_calls),
                metadata=if_given(self.metadata),
                parallel_tool_calls=if_given(self.parallel_tool_calls),
                store=if_given(self.store),
                temperature=if_given(self.temperature),
                top_logprobs=if_given(self.top_logprobs),
                top_p=if_given(self.top_p),
                text=text_param,
                text_format=text_format,  # type: ignore[arg-type]
                extra_headers=self.extra_headers,
                extra_query=self.extra_query,
                extra_body=self.extra_body,
                timeout=self.timeout,
                **self.extra_kwargs,
            )
        except ValidationError as e:
            for error in e.errors():
                logger.error(f"Invalid response ({error['msg']}): {error['input']}")
            raise

        if res.incomplete_details is not None:
            raise ValueError(
                res.incomplete_details.reason or "Response incomplete", res
            )

        for output in res.output:
            content_list = getattr(output, "content", None)
            if content_list is not None:
                for content in content_list:
                    if content.type == "refusal":
                        raise ValueError("Refusal", res)

        assert res.usage is not None
        usage = Usage(res.usage.input_tokens, res.usage.output_tokens)

        if tools is not None:
            for output in res.output:
                if output.type == "function_call":
                    parsed_arguments = getattr(output, "parsed_arguments", None)

                    if parsed_arguments is not None:
                        return Response(cast(R, parsed_arguments), usage)

            raise ValueError("Invalid response", res)

        if text_format is not omit and (parsed := res.output_parsed) is not None:
            return Response(cast(R, parsed), usage)

        if issubclass(self.response_type, str):
            content = res.output_text

            if content:
                return Response(cast(R, content), usage)

        raise ValueError("Invalid response", res)
Provider that calls the OpenAI Responses API and parses structured outputs.
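A minimal sketch (model name illustrative; requires OPENAI_API_KEY). A single Pydantic response_type goes through text_format; a union of models becomes function tools, as in the completions provider:

>>> provider = openai_responses(  # doctest: +SKIP
...     model="gpt-4o-mini",
...     response_type=str,
...     max_output_tokens=256,
... )
>>> provider(["Summarize the adaptation result."])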
@dataclass(slots=True)
class pydantic_ai[T, R](AsyncProvider[PydanticAiPrompt, AgentRunResult[R]]):
    """Provider that runs pydantic-ai agents."""

    agents: MaybeSequence[Agent[T, R]]
    deps: T

    @override
    async def __call_batch__(self, prompt: PydanticAiPrompt) -> AgentRunResult[R]:
        agents = produce_sequence(self.agents)

        user_prompt: str | Sequence[UserContent] | None = None
        message_history: Sequence[ModelMessage] | None = None

        if isinstance(prompt, str):
            user_prompt = prompt
        elif all(isinstance(msg, (ModelRequest, ModelResponse)) for msg in prompt):
            message_history = cast(Sequence[ModelMessage], prompt)
        else:
            user_prompt = cast(Sequence[UserContent], prompt)

        response: AgentRunResult[R] | None = None

        for agent in agents:
            response = await agent.run(
                user_prompt=user_prompt,
                deps=self.deps,
                message_history=message_history,
            )
            message_history = response.all_messages()

        if not response:
            raise ValueError("No agents given.")

        return response
Provider that runs pydantic-ai agents.
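A sketch with a single pydantic-ai agent (model string illustrative). The agent here declares no dependencies, so deps=None matches its deps type; string prompts are passed as user_prompt, while a sequence of ModelRequest/ModelResponse objects is treated as message history:

>>> from pydantic_ai import Agent  # doctest: +SKIP
>>> agent = Agent("openai:gpt-4o")
>>> provider = pydantic_ai(agents=agent, deps=None)
>>> results = provider(["Suggest an adaptation for case A."])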