FunctionModel is similar to TestModel,
but allows greater control over the model's behavior.
Its primary use case is for more advanced unit testing than is possible with TestModel.
Here's a minimal example:
function_model_usage.py
from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models.function import FunctionModel, AgentInfo

my_agent = Agent('openai:gpt-4o')


async def model_function(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
    """Stand-in model: inspect the request, then answer with a fixed text part."""
    print(messages)
    """
    [
        ModelRequest(
            parts=[
                UserPromptPart(
                    content='Testing my agent...',
                    timestamp=datetime.datetime(...),
                    part_kind='user-prompt',
                )
            ],
            kind='request',
        )
    ]
    """
    print(info)
    """
    AgentInfo(
        function_tools=[],
        allow_text_result=True,
        result_tools=[],
        model_settings=None,
    )
    """
    return ModelResponse(parts=[TextPart('hello world')])


async def test_my_agent():
    """Unit test for my_agent, to be run by pytest."""
    # Swap the real model for FunctionModel only inside this context.
    with my_agent.override(model=FunctionModel(model_function)):
        result = await my_agent.run('Testing my agent...')
        assert result.data == 'hello world'
@dataclass(init=False)
class FunctionModel(Model):
    """A model controlled by a local function.

    Apart from `__init__`, all methods are private or match those of the base class.
    """

    function: FunctionDef | None = None
    stream_function: StreamFunctionDef | None = None

    _model_name: str = field(repr=False)
    _system: str = field(default='function', repr=False)

    @overload
    def __init__(self, function: FunctionDef, *, model_name: str | None = None) -> None: ...

    @overload
    def __init__(self, *, stream_function: StreamFunctionDef, model_name: str | None = None) -> None: ...

    @overload
    def __init__(
        self, function: FunctionDef, *, stream_function: StreamFunctionDef, model_name: str | None = None
    ) -> None: ...

    def __init__(
        self,
        function: FunctionDef | None = None,
        *,
        stream_function: StreamFunctionDef | None = None,
        model_name: str | None = None,
    ):
        """Initialize a `FunctionModel`.

        Either `function` or `stream_function` must be provided, providing both is allowed.

        Args:
            function: The function to call for non-streamed requests.
            stream_function: The function to call for streamed requests.
            model_name: The name of the model. If not provided, a name is generated from the function names.
        """
        if function is None and stream_function is None:
            raise TypeError('Either `function` or `stream_function` must be provided')
        self.function = function
        self.stream_function = stream_function

        # Derive a stable default name from whichever function(s) were supplied.
        function_name = self.function.__name__ if self.function is not None else ''
        stream_function_name = self.stream_function.__name__ if self.stream_function is not None else ''
        self._model_name = model_name or f'function:{function_name}:{stream_function_name}'

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> tuple[ModelResponse, usage.Usage]:
        agent_info = AgentInfo(
            model_request_parameters.function_tools,
            model_request_parameters.allow_text_result,
            model_request_parameters.result_tools,
            model_settings,
        )

        assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'

        # Async functions are awaited directly; sync ones run in an executor.
        if inspect.iscoroutinefunction(self.function):
            response = await self.function(messages, agent_info)
        else:
            response_ = await _utils.run_in_executor(self.function, messages, agent_info)
            assert isinstance(response_, ModelResponse), response_
            response = response_

        response.model_name = self._model_name
        # TODO is `messages` right here? Should it just be new messages?
        return response, _estimate_usage(chain(messages, [response]))

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> AsyncIterator[StreamedResponse]:
        agent_info = AgentInfo(
            model_request_parameters.function_tools,
            model_request_parameters.allow_text_result,
            model_request_parameters.result_tools,
            model_settings,
        )

        assert self.stream_function is not None, (
            'FunctionModel must receive a `stream_function` to support streamed requests'
        )

        response_stream = PeekableAsyncStream(self.stream_function(messages, agent_info))

        # Peek so an empty stream fails fast instead of yielding a dead response.
        first = await response_stream.peek()
        if isinstance(first, _utils.Unset):
            raise ValueError('Stream function must return at least one item')

        yield FunctionStreamedResponse(_model_name=self._model_name, _iter=response_stream)

    @property
    def model_name(self) -> str:
        """The model name."""
        return self._model_name

    @property
    def system(self) -> str:
        """The system / model provider."""
        return self._system
`model_name`: The name of the model. If not provided, a name is generated from the function names. Defaults to `None`.
Source code in pydantic_ai_slim/pydantic_ai/models/function.py
(Source lines 61–84.)
def __init__(
    self,
    function: FunctionDef | None = None,
    *,
    stream_function: StreamFunctionDef | None = None,
    model_name: str | None = None,
):
    """Initialize a `FunctionModel`.

    Either `function` or `stream_function` must be provided, providing both is allowed.

    Args:
        function: The function to call for non-streamed requests.
        stream_function: The function to call for streamed requests.
        model_name: The name of the model. If not provided, a name is generated from the function names.
    """
    if function is None and stream_function is None:
        raise TypeError('Either `function` or `stream_function` must be provided')
    self.function = function
    self.stream_function = stream_function

    # Fall back to a generated name built from the supplied function names.
    function_name = self.function.__name__ if self.function is not None else ''
    stream_function_name = self.stream_function.__name__ if self.stream_function is not None else ''
    self._model_name = model_name or f'function:{function_name}:{stream_function_name}'
@dataclass(frozen=True)
class AgentInfo:
    """Information about an agent.

    This is passed as the second argument to functions used within
    [`FunctionModel`][pydantic_ai.models.function.FunctionModel].
    """

    function_tools: list[ToolDefinition]
    """The function tools available on this agent.

    These are the tools registered via the [`tool`][pydantic_ai.Agent.tool] and
    [`tool_plain`][pydantic_ai.Agent.tool_plain] decorators.
    """
    allow_text_result: bool
    """Whether a plain text result is allowed."""
    result_tools: list[ToolDefinition]
    """The tools that can be called as the final result of the run."""
    model_settings: ModelSettings | None
    """The model settings passed to the run call."""
Used to describe a chunk when streaming structured responses.
Source code in pydantic_ai_slim/pydantic_ai/models/function.py
(Source lines 169–181.)
@dataclassclassDeltaToolCall:"""Incremental change to a tool call. Used to describe a chunk when streaming structured responses. """name:str|None=None"""Incremental change to the name of the tool."""json_args:str|None=None"""Incremental change to the arguments as JSON"""tool_call_id:str|None=None"""Incremental change to the tool call ID."""
While this is defined as having a return type of `AsyncIterator[Union[str, DeltaToolCalls]]`, it should
really be considered as `Union[AsyncIterator[str], AsyncIterator[DeltaToolCalls]]`:
i.e. you need to yield all text or all `DeltaToolCalls`, not mix them.
@dataclass
class FunctionStreamedResponse(StreamedResponse):
    """Implementation of `StreamedResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""

    _model_name: str
    _iter: AsyncIterator[str | DeltaToolCalls]
    _timestamp: datetime = field(default_factory=_utils.now_utc)

    def __post_init__(self):
        # Seed usage with the fixed request overhead estimate.
        self._usage += _estimate_usage([])

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        async for item in self._iter:
            if isinstance(item, str):
                # Plain text chunk: count its tokens and emit a text delta.
                response_tokens = _estimate_string_tokens(item)
                self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens)
                yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=item)
            else:
                # Mapping of tool-call index -> DeltaToolCall chunk.
                delta_tool_calls = item
                for dtc_index, delta_tool_call in delta_tool_calls.items():
                    if delta_tool_call.json_args:
                        response_tokens = _estimate_string_tokens(delta_tool_call.json_args)
                        self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens)
                    maybe_event = self._parts_manager.handle_tool_call_delta(
                        vendor_part_id=dtc_index,
                        tool_name=delta_tool_call.name,
                        args=delta_tool_call.json_args,
                        tool_call_id=delta_tool_call.tool_call_id,
                    )
                    # The parts manager may buffer deltas and return no event yet.
                    if maybe_event is not None:
                        yield maybe_event

    @property
    def model_name(self) -> str:
        """Get the model name of the response."""
        return self._model_name

    @property
    def timestamp(self) -> datetime:
        """Get the timestamp of the response."""
        return self._timestamp