examples : make pydantic scripts pass mypy and support py3.8 (#5099)

This commit is contained in:
Jared Van Bortel 2024-01-25 14:51:24 -05:00 committed by GitHub
parent 256d1bb0dd
commit d292f4f204
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 88 additions and 121 deletions

View File

@ -1,14 +1,14 @@
# Function calling example using pydantic models.
import datetime
import importlib
import json
from enum import Enum
from typing import Union, Optional
from typing import Optional, Union
import requests
from pydantic import BaseModel, Field
import importlib
from pydantic_models_to_grammar import generate_gbnf_grammar_and_documentation, convert_dictionary_to_pydantic_model, add_run_method_to_dynamic_model, create_dynamic_model_from_function
from pydantic_models_to_grammar import (add_run_method_to_dynamic_model, convert_dictionary_to_pydantic_model,
create_dynamic_model_from_function, generate_gbnf_grammar_and_documentation)
# Function to get completion on the llama.cpp server with grammar.
@ -35,7 +35,7 @@ class SendMessageToUser(BaseModel):
print(self.message)
# Enum for the calculator function.
# Enum for the calculator tool.
class MathOperation(Enum):
ADD = "add"
SUBTRACT = "subtract"
@ -43,7 +43,7 @@ class MathOperation(Enum):
DIVIDE = "divide"
# Very simple calculator tool for the agent.
# Simple pydantic calculator tool for the agent that can add, subtract, multiply, and divide. Docstring and description of fields will be used in system prompt.
class Calculator(BaseModel):
"""
Perform a math operation on two numbers.
@ -148,37 +148,6 @@ def get_current_datetime(output_format: Optional[str] = None):
return datetime.datetime.now().strftime(output_format)
# Enum for the calculator tool.
class MathOperation(Enum):
ADD = "add"
SUBTRACT = "subtract"
MULTIPLY = "multiply"
DIVIDE = "divide"
# Simple pydantic calculator tool for the agent that can add, subtract, multiply, and divide. Docstring and description of fields will be used in system prompt.
class Calculator(BaseModel):
"""
Perform a math operation on two numbers.
"""
number_one: Union[int, float] = Field(..., description="First number.")
operation: MathOperation = Field(..., description="Math operation to perform.")
number_two: Union[int, float] = Field(..., description="Second number.")
def run(self):
if self.operation == MathOperation.ADD:
return self.number_one + self.number_two
elif self.operation == MathOperation.SUBTRACT:
return self.number_one - self.number_two
elif self.operation == MathOperation.MULTIPLY:
return self.number_one * self.number_two
elif self.operation == MathOperation.DIVIDE:
return self.number_one / self.number_two
else:
raise ValueError("Unknown operation.")
# Example function to get the weather
def get_current_weather(location, unit):
"""Get the current weather in a given location"""

View File

@ -1,15 +1,21 @@
from __future__ import annotations
import inspect
import json
import re
from copy import copy
from inspect import isclass, getdoc
from types import NoneType
from enum import Enum
from inspect import getdoc, isclass
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin, get_type_hints
from docstring_parser import parse
from pydantic import BaseModel, create_model, Field
from typing import Any, Type, List, get_args, get_origin, Tuple, Union, Optional, _GenericAlias
from enum import Enum
from typing import get_type_hints, Callable
import re
from pydantic import BaseModel, Field, create_model
if TYPE_CHECKING:
from types import GenericAlias
else:
# python 3.8 compat
from typing import _GenericAlias as GenericAlias
class PydanticDataType(Enum):
@ -43,7 +49,7 @@ class PydanticDataType(Enum):
SET = "set"
def map_pydantic_type_to_gbnf(pydantic_type: Type[Any]) -> str:
def map_pydantic_type_to_gbnf(pydantic_type: type[Any]) -> str:
if isclass(pydantic_type) and issubclass(pydantic_type, str):
return PydanticDataType.STRING.value
elif isclass(pydantic_type) and issubclass(pydantic_type, bool):
@ -57,22 +63,22 @@ def map_pydantic_type_to_gbnf(pydantic_type: Type[Any]) -> str:
elif isclass(pydantic_type) and issubclass(pydantic_type, BaseModel):
return format_model_and_field_name(pydantic_type.__name__)
elif get_origin(pydantic_type) == list:
elif get_origin(pydantic_type) is list:
element_type = get_args(pydantic_type)[0]
return f"{map_pydantic_type_to_gbnf(element_type)}-list"
elif get_origin(pydantic_type) == set:
elif get_origin(pydantic_type) is set:
element_type = get_args(pydantic_type)[0]
return f"{map_pydantic_type_to_gbnf(element_type)}-set"
elif get_origin(pydantic_type) == Union:
elif get_origin(pydantic_type) is Union:
union_types = get_args(pydantic_type)
union_rules = [map_pydantic_type_to_gbnf(ut) for ut in union_types]
return f"union-{'-or-'.join(union_rules)}"
elif get_origin(pydantic_type) == Optional:
elif get_origin(pydantic_type) is Optional:
element_type = get_args(pydantic_type)[0]
return f"optional-{map_pydantic_type_to_gbnf(element_type)}"
elif isclass(pydantic_type):
return f"{PydanticDataType.CUSTOM_CLASS.value}-{format_model_and_field_name(pydantic_type.__name__)}"
elif get_origin(pydantic_type) == dict:
elif get_origin(pydantic_type) is dict:
key_type, value_type = get_args(pydantic_type)
return f"custom-dict-key-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(key_type))}-value-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(value_type))}"
else:
@ -106,7 +112,6 @@ def get_members_structure(cls, rule_name):
return f"{cls.__name__.lower()} ::= " + " | ".join(members)
if cls.__annotations__ and cls.__annotations__ != {}:
result = f'{rule_name} ::= "{{"'
type_list_rules = []
# Modify this comprehension
members = [
f' "\\"{name}\\"" ":" {map_pydantic_type_to_gbnf(param_type)}'
@ -116,27 +121,25 @@ def get_members_structure(cls, rule_name):
result += '"," '.join(members)
result += ' "}"'
return result, type_list_rules
elif rule_name == "custom-class-any":
return result
if rule_name == "custom-class-any":
result = f"{rule_name} ::= "
result += "value"
type_list_rules = []
return result, type_list_rules
else:
init_signature = inspect.signature(cls.__init__)
parameters = init_signature.parameters
result = f'{rule_name} ::= "{{"'
type_list_rules = []
# Modify this comprehension too
members = [
f' "\\"{name}\\"" ":" {map_pydantic_type_to_gbnf(param.annotation)}'
for name, param in parameters.items()
if name != "self" and param.annotation != inspect.Parameter.empty
]
return result
result += '", "'.join(members)
result += ' "}"'
return result, type_list_rules
init_signature = inspect.signature(cls.__init__)
parameters = init_signature.parameters
result = f'{rule_name} ::= "{{"'
# Modify this comprehension too
members = [
f' "\\"{name}\\"" ":" {map_pydantic_type_to_gbnf(param.annotation)}'
for name, param in parameters.items()
if name != "self" and param.annotation != inspect.Parameter.empty
]
result += '", "'.join(members)
result += ' "}"'
return result
def regex_to_gbnf(regex_pattern: str) -> str:
@ -269,7 +272,7 @@ def generate_gbnf_float_rules(max_digit=None, min_digit=None, max_precision=None
def generate_gbnf_rule_for_type(
model_name, field_name, field_type, is_optional, processed_models, created_rules, field_info=None
) -> Tuple[str, list]:
) -> tuple[str, list[str]]:
"""
Generate GBNF rule for a given field type.
@ -283,7 +286,7 @@ def generate_gbnf_rule_for_type(
:param field_info: Additional information about the field (optional).
:return: Tuple containing the GBNF type and a list of additional rules.
:rtype: Tuple[str, list]
:rtype: tuple[str, list]
"""
rules = []
@ -321,8 +324,7 @@ def generate_gbnf_rule_for_type(
gbnf_type, rules = model_name + "-" + field_name, rules
elif gbnf_type.startswith("custom-class-"):
nested_model_rules, field_types = get_members_structure(field_type, gbnf_type)
rules.append(nested_model_rules)
rules.append(get_members_structure(field_type, gbnf_type))
elif gbnf_type.startswith("custom-dict-"):
key_type, value_type = get_args(field_type)
@ -341,14 +343,14 @@ def generate_gbnf_rule_for_type(
union_rules = []
for union_type in union_types:
if isinstance(union_type, _GenericAlias):
if isinstance(union_type, GenericAlias):
union_gbnf_type, union_rules_list = generate_gbnf_rule_for_type(
model_name, field_name, union_type, False, processed_models, created_rules
)
union_rules.append(union_gbnf_type)
rules.extend(union_rules_list)
elif not issubclass(union_type, NoneType):
elif not issubclass(union_type, type(None)):
union_gbnf_type, union_rules_list = generate_gbnf_rule_for_type(
model_name, field_name, union_type, False, processed_models, created_rules
)
@ -424,14 +426,10 @@ def generate_gbnf_rule_for_type(
else:
gbnf_type, rules = gbnf_type, []
if gbnf_type not in created_rules:
return gbnf_type, rules
else:
if gbnf_type in created_rules:
return gbnf_type, rules
return gbnf_type, rules
def generate_gbnf_grammar(model: Type[BaseModel], processed_models: set, created_rules: dict) -> (list, bool, bool):
def generate_gbnf_grammar(model: type[BaseModel], processed_models: set[type[BaseModel]], created_rules: dict[str, list[str]]) -> tuple[list[str], bool]:
"""
Generate GBnF Grammar
@ -452,7 +450,7 @@ def generate_gbnf_grammar(model: Type[BaseModel], processed_models: set, created
```
"""
if model in processed_models:
return []
return [], False
processed_models.add(model)
model_name = format_model_and_field_name(model.__name__)
@ -518,7 +516,7 @@ def generate_gbnf_grammar(model: Type[BaseModel], processed_models: set, created
def generate_gbnf_grammar_from_pydantic_models(
models: List[Type[BaseModel]], outer_object_name: str = None, outer_object_content: str = None,
models: list[type[BaseModel]], outer_object_name: str | None = None, outer_object_content: str | None = None,
list_of_outputs: bool = False
) -> str:
"""
@ -528,7 +526,7 @@ def generate_gbnf_grammar_from_pydantic_models(
* grammar.
Args:
models (List[Type[BaseModel]]): A list of Pydantic models to generate the grammar from.
models (list[type[BaseModel]]): A list of Pydantic models to generate the grammar from.
outer_object_name (str): Outer object name for the GBNF grammar. If None, no outer object will be generated. Eg. "function" for function calling.
outer_object_content (str): Content for the outer rule in the GBNF grammar. Eg. "function_parameters" or "params" for function calling.
list_of_outputs (str, optional): Allows a list of output objects
@ -543,9 +541,9 @@ def generate_gbnf_grammar_from_pydantic_models(
# root ::= UserModel | PostModel
# ...
"""
processed_models = set()
processed_models: set[type[BaseModel]] = set()
all_rules = []
created_rules = {}
created_rules: dict[str, list[str]] = {}
if outer_object_name is None:
for model in models:
model_rules, _ = generate_gbnf_grammar(model, processed_models, created_rules)
@ -608,7 +606,7 @@ def get_primitive_grammar(grammar):
Returns:
str: GBNF primitive grammar string.
"""
type_list = []
type_list: list[type[object]] = []
if "string-list" in grammar:
type_list.append(str)
if "boolean-list" in grammar:
@ -666,14 +664,14 @@ triple-quotes ::= "'''" """
def generate_markdown_documentation(
pydantic_models: List[Type[BaseModel]], model_prefix="Model", fields_prefix="Fields",
pydantic_models: list[type[BaseModel]], model_prefix="Model", fields_prefix="Fields",
documentation_with_field_description=True
) -> str:
"""
Generate markdown documentation for a list of Pydantic models.
Args:
pydantic_models (List[Type[BaseModel]]): List of Pydantic model classes.
pydantic_models (list[type[BaseModel]]): list of Pydantic model classes.
model_prefix (str): Prefix for the model section.
fields_prefix (str): Prefix for the fields section.
documentation_with_field_description (bool): Include field descriptions in the documentation.
@ -731,7 +729,7 @@ def generate_markdown_documentation(
def generate_field_markdown(
field_name: str, field_type: Type[Any], model: Type[BaseModel], depth=1,
field_name: str, field_type: type[Any], model: type[BaseModel], depth=1,
documentation_with_field_description=True
) -> str:
"""
@ -739,8 +737,8 @@ def generate_field_markdown(
Args:
field_name (str): Name of the field.
field_type (Type[Any]): Type of the field.
model (Type[BaseModel]): Pydantic model class.
field_type (type[Any]): Type of the field.
model (type[BaseModel]): Pydantic model class.
depth (int): Indentation depth in the documentation.
documentation_with_field_description (bool): Include field descriptions in the documentation.
@ -798,7 +796,7 @@ def generate_field_markdown(
return field_text
def format_json_example(example: dict, depth: int) -> str:
def format_json_example(example: dict[str, Any], depth: int) -> str:
"""
Format a JSON example into a readable string with indentation.
@ -819,14 +817,14 @@ def format_json_example(example: dict, depth: int) -> str:
def generate_text_documentation(
pydantic_models: List[Type[BaseModel]], model_prefix="Model", fields_prefix="Fields",
pydantic_models: list[type[BaseModel]], model_prefix="Model", fields_prefix="Fields",
documentation_with_field_description=True
) -> str:
"""
Generate text documentation for a list of Pydantic models.
Args:
pydantic_models (List[Type[BaseModel]]): List of Pydantic model classes.
pydantic_models (list[type[BaseModel]]): List of Pydantic model classes.
model_prefix (str): Prefix for the model section.
fields_prefix (str): Prefix for the fields section.
documentation_with_field_description (bool): Include field descriptions in the documentation.
@ -885,7 +883,7 @@ def generate_text_documentation(
def generate_field_text(
field_name: str, field_type: Type[Any], model: Type[BaseModel], depth=1,
field_name: str, field_type: type[Any], model: type[BaseModel], depth=1,
documentation_with_field_description=True
) -> str:
"""
@ -893,8 +891,8 @@ def generate_field_text(
Args:
field_name (str): Name of the field.
field_type (Type[Any]): Type of the field.
model (Type[BaseModel]): Pydantic model class.
field_type (type[Any]): Type of the field.
model (type[BaseModel]): Pydantic model class.
depth (int): Indentation depth in the documentation.
documentation_with_field_description (bool): Include field descriptions in the documentation.
@ -1017,8 +1015,8 @@ def generate_and_save_gbnf_grammar_and_documentation(
pydantic_model_list,
grammar_file_path="./generated_grammar.gbnf",
documentation_file_path="./generated_grammar_documentation.md",
outer_object_name: str = None,
outer_object_content: str = None,
outer_object_name: str | None = None,
outer_object_content: str | None = None,
model_prefix: str = "Output Model",
fields_prefix: str = "Output Fields",
list_of_outputs: bool = False,
@ -1053,8 +1051,8 @@ def generate_and_save_gbnf_grammar_and_documentation(
def generate_gbnf_grammar_and_documentation(
pydantic_model_list,
outer_object_name: str = None,
outer_object_content: str = None,
outer_object_name: str | None = None,
outer_object_content: str | None = None,
model_prefix: str = "Output Model",
fields_prefix: str = "Output Fields",
list_of_outputs: bool = False,
@ -1086,9 +1084,9 @@ def generate_gbnf_grammar_and_documentation(
def generate_gbnf_grammar_and_documentation_from_dictionaries(
dictionaries: List[dict],
outer_object_name: str = None,
outer_object_content: str = None,
dictionaries: list[dict[str, Any]],
outer_object_name: str | None = None,
outer_object_content: str | None = None,
model_prefix: str = "Output Model",
fields_prefix: str = "Output Fields",
list_of_outputs: bool = False,
@ -1098,7 +1096,7 @@ def generate_gbnf_grammar_and_documentation_from_dictionaries(
Generate GBNF grammar and documentation from a list of dictionaries.
Args:
dictionaries (List[dict]): List of dictionaries representing Pydantic models.
dictionaries (list[dict]): List of dictionaries representing Pydantic models.
outer_object_name (str): Outer object name for the GBNF grammar. If None, no outer object will be generated. Eg. "function" for function calling.
outer_object_content (str): Content for the outer rule in the GBNF grammar. Eg. "function_parameters" or "params" for function calling.
model_prefix (str): Prefix for the model section in the documentation.
@ -1120,7 +1118,7 @@ def generate_gbnf_grammar_and_documentation_from_dictionaries(
return grammar, documentation
def create_dynamic_model_from_function(func: Callable):
def create_dynamic_model_from_function(func: Callable[..., Any]):
"""
Creates a dynamic Pydantic model from a given function's type hints and adds the function as a 'run' method.
@ -1135,6 +1133,7 @@ def create_dynamic_model_from_function(func: Callable):
sig = inspect.signature(func)
# Parse the docstring
assert func.__doc__ is not None
docstring = parse(func.__doc__)
dynamic_fields = {}
@ -1157,7 +1156,6 @@ def create_dynamic_model_from_function(func: Callable):
f"Parameter '{param.name}' in function '{func.__name__}' lacks a description in the docstring")
# Add parameter details to the schema
param_doc = next((d for d in docstring.params if d.arg_name == param.name), None)
param_docs.append((param.name, param_doc))
if param.default == inspect.Parameter.empty:
default_value = ...
@ -1166,10 +1164,10 @@ def create_dynamic_model_from_function(func: Callable):
dynamic_fields[param.name] = (
param.annotation if param.annotation != inspect.Parameter.empty else str, default_value)
# Creating the dynamic model
dynamic_model = create_model(f"{func.__name__}", **dynamic_fields)
dynamic_model = create_model(f"{func.__name__}", **dynamic_fields) # type: ignore[call-overload]
for param_doc in param_docs:
dynamic_model.model_fields[param_doc[0]].description = param_doc[1].description
for name, param_doc in param_docs:
dynamic_model.model_fields[name].description = param_doc.description
dynamic_model.__doc__ = docstring.short_description
@ -1182,16 +1180,16 @@ def create_dynamic_model_from_function(func: Callable):
return dynamic_model
def add_run_method_to_dynamic_model(model: Type[BaseModel], func: Callable):
def add_run_method_to_dynamic_model(model: type[BaseModel], func: Callable[..., Any]):
"""
Add a 'run' method to a dynamic Pydantic model, using the provided function.
Args:
model (Type[BaseModel]): Dynamic Pydantic model class.
model (type[BaseModel]): Dynamic Pydantic model class.
func (Callable): Function to be added as a 'run' method to the model.
Returns:
Type[BaseModel]: Pydantic model class with the added 'run' method.
type[BaseModel]: Pydantic model class with the added 'run' method.
"""
def run_method_wrapper(self):
@ -1204,15 +1202,15 @@ def add_run_method_to_dynamic_model(model: Type[BaseModel], func: Callable):
return model
def create_dynamic_models_from_dictionaries(dictionaries: List[dict]):
def create_dynamic_models_from_dictionaries(dictionaries: list[dict[str, Any]]):
"""
Create a list of dynamic Pydantic model classes from a list of dictionaries.
Args:
dictionaries (List[dict]): List of dictionaries representing model structures.
dictionaries (list[dict]): List of dictionaries representing model structures.
Returns:
List[Type[BaseModel]]: List of generated dynamic Pydantic model classes.
list[type[BaseModel]]: List of generated dynamic Pydantic model classes.
"""
dynamic_models = []
for func in dictionaries:
@ -1249,7 +1247,7 @@ def list_to_enum(enum_name, values):
return Enum(enum_name, {value: value for value in values})
def convert_dictionary_to_pydantic_model(dictionary: dict, model_name: str = "CustomModel") -> Type[BaseModel]:
def convert_dictionary_to_pydantic_model(dictionary: dict[str, Any], model_name: str = "CustomModel") -> type[Any]:
"""
Convert a dictionary to a Pydantic model class.
@ -1258,9 +1256,9 @@ def convert_dictionary_to_pydantic_model(dictionary: dict, model_name: str = "Cu
model_name (str): Name of the generated Pydantic model.
Returns:
Type[BaseModel]: Generated Pydantic model class.
type[BaseModel]: Generated Pydantic model class.
"""
fields = {}
fields: dict[str, Any] = {}
if "properties" in dictionary:
for field_name, field_data in dictionary.get("properties", {}).items():
@ -1277,7 +1275,7 @@ def convert_dictionary_to_pydantic_model(dictionary: dict, model_name: str = "Cu
if items != {}:
array = {"properties": items}
array_type = convert_dictionary_to_pydantic_model(array, f"{model_name}_{field_name}_items")
fields[field_name] = (List[array_type], ...)
fields[field_name] = (List[array_type], ...) # type: ignore[valid-type]
else:
fields[field_name] = (list, ...)
elif field_type == "object":