pydantic : replace uses of __annotations__ with get_type_hints (#8474)

* pydantic : replace uses of __annotations__ with get_type_hints

* pydantic : fix Python 3.9 and 3.10 support
This commit is contained in:
compilade 2024-07-14 19:51:21 -04:00 committed by GitHub
parent aaab2419ea
commit 090fca7a07
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 46 additions and 34 deletions

View File

@ -6,7 +6,7 @@ import re
from copy import copy from copy import copy
from enum import Enum from enum import Enum
from inspect import getdoc, isclass from inspect import getdoc, isclass
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin, get_type_hints
from docstring_parser import parse from docstring_parser import parse
from pydantic import BaseModel, create_model from pydantic import BaseModel, create_model
@ -53,35 +53,38 @@ class PydanticDataType(Enum):
def map_pydantic_type_to_gbnf(pydantic_type: type[Any]) -> str: def map_pydantic_type_to_gbnf(pydantic_type: type[Any]) -> str:
if isclass(pydantic_type) and issubclass(pydantic_type, str): origin_type = get_origin(pydantic_type)
origin_type = pydantic_type if origin_type is None else origin_type
if isclass(origin_type) and issubclass(origin_type, str):
return PydanticDataType.STRING.value return PydanticDataType.STRING.value
elif isclass(pydantic_type) and issubclass(pydantic_type, bool): elif isclass(origin_type) and issubclass(origin_type, bool):
return PydanticDataType.BOOLEAN.value return PydanticDataType.BOOLEAN.value
elif isclass(pydantic_type) and issubclass(pydantic_type, int): elif isclass(origin_type) and issubclass(origin_type, int):
return PydanticDataType.INTEGER.value return PydanticDataType.INTEGER.value
elif isclass(pydantic_type) and issubclass(pydantic_type, float): elif isclass(origin_type) and issubclass(origin_type, float):
return PydanticDataType.FLOAT.value return PydanticDataType.FLOAT.value
elif isclass(pydantic_type) and issubclass(pydantic_type, Enum): elif isclass(origin_type) and issubclass(origin_type, Enum):
return PydanticDataType.ENUM.value return PydanticDataType.ENUM.value
elif isclass(pydantic_type) and issubclass(pydantic_type, BaseModel): elif isclass(origin_type) and issubclass(origin_type, BaseModel):
return format_model_and_field_name(pydantic_type.__name__) return format_model_and_field_name(origin_type.__name__)
elif get_origin(pydantic_type) is list: elif origin_type is list:
element_type = get_args(pydantic_type)[0] element_type = get_args(pydantic_type)[0]
return f"{map_pydantic_type_to_gbnf(element_type)}-list" return f"{map_pydantic_type_to_gbnf(element_type)}-list"
elif get_origin(pydantic_type) is set: elif origin_type is set:
element_type = get_args(pydantic_type)[0] element_type = get_args(pydantic_type)[0]
return f"{map_pydantic_type_to_gbnf(element_type)}-set" return f"{map_pydantic_type_to_gbnf(element_type)}-set"
elif get_origin(pydantic_type) is Union: elif origin_type is Union:
union_types = get_args(pydantic_type) union_types = get_args(pydantic_type)
union_rules = [map_pydantic_type_to_gbnf(ut) for ut in union_types] union_rules = [map_pydantic_type_to_gbnf(ut) for ut in union_types]
return f"union-{'-or-'.join(union_rules)}" return f"union-{'-or-'.join(union_rules)}"
elif get_origin(pydantic_type) is Optional: elif origin_type is Optional:
element_type = get_args(pydantic_type)[0] element_type = get_args(pydantic_type)[0]
return f"optional-{map_pydantic_type_to_gbnf(element_type)}" return f"optional-{map_pydantic_type_to_gbnf(element_type)}"
elif isclass(pydantic_type): elif isclass(origin_type):
return f"{PydanticDataType.CUSTOM_CLASS.value}-{format_model_and_field_name(pydantic_type.__name__)}" return f"{PydanticDataType.CUSTOM_CLASS.value}-{format_model_and_field_name(origin_type.__name__)}"
elif get_origin(pydantic_type) is dict: elif origin_type is dict:
key_type, value_type = get_args(pydantic_type) key_type, value_type = get_args(pydantic_type)
return f"custom-dict-key-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(key_type))}-value-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(value_type))}" return f"custom-dict-key-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(key_type))}-value-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(value_type))}"
else: else:
@ -118,7 +121,7 @@ def get_members_structure(cls, rule_name):
# Modify this comprehension # Modify this comprehension
members = [ members = [
f' "\\"{name}\\"" ":" {map_pydantic_type_to_gbnf(param_type)}' f' "\\"{name}\\"" ":" {map_pydantic_type_to_gbnf(param_type)}'
for name, param_type in cls.__annotations__.items() for name, param_type in get_type_hints(cls).items()
if name != "self" if name != "self"
] ]
@ -297,17 +300,20 @@ def generate_gbnf_rule_for_type(
field_name = format_model_and_field_name(field_name) field_name = format_model_and_field_name(field_name)
gbnf_type = map_pydantic_type_to_gbnf(field_type) gbnf_type = map_pydantic_type_to_gbnf(field_type)
if isclass(field_type) and issubclass(field_type, BaseModel): origin_type = get_origin(field_type)
origin_type = field_type if origin_type is None else origin_type
if isclass(origin_type) and issubclass(origin_type, BaseModel):
nested_model_name = format_model_and_field_name(field_type.__name__) nested_model_name = format_model_and_field_name(field_type.__name__)
nested_model_rules, _ = generate_gbnf_grammar(field_type, processed_models, created_rules) nested_model_rules, _ = generate_gbnf_grammar(field_type, processed_models, created_rules)
rules.extend(nested_model_rules) rules.extend(nested_model_rules)
gbnf_type, rules = nested_model_name, rules gbnf_type, rules = nested_model_name, rules
elif isclass(field_type) and issubclass(field_type, Enum): elif isclass(origin_type) and issubclass(origin_type, Enum):
enum_values = [f'"\\"{e.value}\\""' for e in field_type] # Adding escaped quotes enum_values = [f'"\\"{e.value}\\""' for e in field_type] # Adding escaped quotes
enum_rule = f"{model_name}-{field_name} ::= {' | '.join(enum_values)}" enum_rule = f"{model_name}-{field_name} ::= {' | '.join(enum_values)}"
rules.append(enum_rule) rules.append(enum_rule)
gbnf_type, rules = model_name + "-" + field_name, rules gbnf_type, rules = model_name + "-" + field_name, rules
elif get_origin(field_type) == list: # Array elif origin_type is list: # Array
element_type = get_args(field_type)[0] element_type = get_args(field_type)[0]
element_rule_name, additional_rules = generate_gbnf_rule_for_type( element_rule_name, additional_rules = generate_gbnf_rule_for_type(
model_name, f"{field_name}-element", element_type, is_optional, processed_models, created_rules model_name, f"{field_name}-element", element_type, is_optional, processed_models, created_rules
@ -317,7 +323,7 @@ def generate_gbnf_rule_for_type(
rules.append(array_rule) rules.append(array_rule)
gbnf_type, rules = model_name + "-" + field_name, rules gbnf_type, rules = model_name + "-" + field_name, rules
elif get_origin(field_type) == set or field_type == set: # Array elif origin_type is set: # Array
element_type = get_args(field_type)[0] element_type = get_args(field_type)[0]
element_rule_name, additional_rules = generate_gbnf_rule_for_type( element_rule_name, additional_rules = generate_gbnf_rule_for_type(
model_name, f"{field_name}-element", element_type, is_optional, processed_models, created_rules model_name, f"{field_name}-element", element_type, is_optional, processed_models, created_rules
@ -371,7 +377,7 @@ def generate_gbnf_rule_for_type(
gbnf_type = f"{model_name}-{field_name}-optional" gbnf_type = f"{model_name}-{field_name}-optional"
else: else:
gbnf_type = f"{model_name}-{field_name}-union" gbnf_type = f"{model_name}-{field_name}-union"
elif isclass(field_type) and issubclass(field_type, str): elif isclass(origin_type) and issubclass(origin_type, str):
if field_info and hasattr(field_info, "json_schema_extra") and field_info.json_schema_extra is not None: if field_info and hasattr(field_info, "json_schema_extra") and field_info.json_schema_extra is not None:
triple_quoted_string = field_info.json_schema_extra.get("triple_quoted_string", False) triple_quoted_string = field_info.json_schema_extra.get("triple_quoted_string", False)
markdown_string = field_info.json_schema_extra.get("markdown_code_block", False) markdown_string = field_info.json_schema_extra.get("markdown_code_block", False)
@ -387,8 +393,8 @@ def generate_gbnf_rule_for_type(
gbnf_type = PydanticDataType.STRING.value gbnf_type = PydanticDataType.STRING.value
elif ( elif (
isclass(field_type) isclass(origin_type)
and issubclass(field_type, float) and issubclass(origin_type, float)
and field_info and field_info
and hasattr(field_info, "json_schema_extra") and hasattr(field_info, "json_schema_extra")
and field_info.json_schema_extra is not None and field_info.json_schema_extra is not None
@ -413,8 +419,8 @@ def generate_gbnf_rule_for_type(
) )
elif ( elif (
isclass(field_type) isclass(origin_type)
and issubclass(field_type, int) and issubclass(origin_type, int)
and field_info and field_info
and hasattr(field_info, "json_schema_extra") and hasattr(field_info, "json_schema_extra")
and field_info.json_schema_extra is not None and field_info.json_schema_extra is not None
@ -462,7 +468,7 @@ def generate_gbnf_grammar(model: type[BaseModel], processed_models: set[type[Bas
if not issubclass(model, BaseModel): if not issubclass(model, BaseModel):
# For non-Pydantic classes, generate model_fields from __annotations__ or __init__ # For non-Pydantic classes, generate model_fields from __annotations__ or __init__
if hasattr(model, "__annotations__") and model.__annotations__: if hasattr(model, "__annotations__") and model.__annotations__:
model_fields = {name: (typ, ...) for name, typ in model.__annotations__.items()} # pyright: ignore[reportGeneralTypeIssues] model_fields = {name: (typ, ...) for name, typ in get_type_hints(model).items()}
else: else:
init_signature = inspect.signature(model.__init__) init_signature = inspect.signature(model.__init__)
parameters = init_signature.parameters parameters = init_signature.parameters
@ -470,7 +476,7 @@ def generate_gbnf_grammar(model: type[BaseModel], processed_models: set[type[Bas
name != "self"} name != "self"}
else: else:
# For Pydantic models, use model_fields and check for ellipsis (required fields) # For Pydantic models, use model_fields and check for ellipsis (required fields)
model_fields = model.__annotations__ model_fields = get_type_hints(model)
model_rule_parts = [] model_rule_parts = []
nested_rules = [] nested_rules = []
@ -706,7 +712,7 @@ def generate_markdown_documentation(
else: else:
documentation += f" Fields:\n" # noqa: F541 documentation += f" Fields:\n" # noqa: F541
if isclass(model) and issubclass(model, BaseModel): if isclass(model) and issubclass(model, BaseModel):
for name, field_type in model.__annotations__.items(): for name, field_type in get_type_hints(model).items():
# if name == "markdown_code_block": # if name == "markdown_code_block":
# continue # continue
if get_origin(field_type) == list: if get_origin(field_type) == list:
@ -754,14 +760,17 @@ def generate_field_markdown(
field_info = model.model_fields.get(field_name) field_info = model.model_fields.get(field_name)
field_description = field_info.description if field_info and field_info.description else "" field_description = field_info.description if field_info and field_info.description else ""
if get_origin(field_type) == list: origin_type = get_origin(field_type)
origin_type = field_type if origin_type is None else origin_type
if origin_type == list:
element_type = get_args(field_type)[0] element_type = get_args(field_type)[0]
field_text = f"{indent}{field_name} ({format_model_and_field_name(field_type.__name__)} of {format_model_and_field_name(element_type.__name__)})" field_text = f"{indent}{field_name} ({format_model_and_field_name(field_type.__name__)} of {format_model_and_field_name(element_type.__name__)})"
if field_description != "": if field_description != "":
field_text += ":\n" field_text += ":\n"
else: else:
field_text += "\n" field_text += "\n"
elif get_origin(field_type) == Union: elif origin_type == Union:
element_types = get_args(field_type) element_types = get_args(field_type)
types = [] types = []
for element_type in element_types: for element_type in element_types:
@ -792,9 +801,9 @@ def generate_field_markdown(
example_text = f"'{field_example}'" if isinstance(field_example, str) else field_example example_text = f"'{field_example}'" if isinstance(field_example, str) else field_example
field_text += f"{indent} Example: {example_text}\n" field_text += f"{indent} Example: {example_text}\n"
if isclass(field_type) and issubclass(field_type, BaseModel): if isclass(origin_type) and issubclass(origin_type, BaseModel):
field_text += f"{indent} Details:\n" field_text += f"{indent} Details:\n"
for name, type_ in field_type.__annotations__.items(): for name, type_ in get_type_hints(field_type).items():
field_text += generate_field_markdown(name, type_, field_type, depth + 2) field_text += generate_field_markdown(name, type_, field_type, depth + 2)
return field_text return field_text
@ -855,7 +864,7 @@ def generate_text_documentation(
if isclass(model) and issubclass(model, BaseModel): if isclass(model) and issubclass(model, BaseModel):
documentation_fields = "" documentation_fields = ""
for name, field_type in model.__annotations__.items(): for name, field_type in get_type_hints(model).items():
# if name == "markdown_code_block": # if name == "markdown_code_block":
# continue # continue
if get_origin(field_type) == list: if get_origin(field_type) == list:
@ -948,7 +957,7 @@ def generate_field_text(
if isclass(field_type) and issubclass(field_type, BaseModel): if isclass(field_type) and issubclass(field_type, BaseModel):
field_text += f"{indent} Details:\n" field_text += f"{indent} Details:\n"
for name, type_ in field_type.__annotations__.items(): for name, type_ in get_type_hints(field_type).items():
field_text += generate_field_text(name, type_, field_type, depth + 2) field_text += generate_field_text(name, type_, field_type, depth + 2)
return field_text return field_text

View File

@ -20,6 +20,8 @@ def create_completion(prompt, grammar):
response = requests.post("http://127.0.0.1:8080/completion", headers=headers, json=data) response = requests.post("http://127.0.0.1:8080/completion", headers=headers, json=data)
data = response.json() data = response.json()
assert data.get("error") is None, data
print(data["content"]) print(data["content"])
return data["content"] return data["content"]

View File

@ -1,2 +1,3 @@
docstring_parser~=0.15 docstring_parser~=0.15
pydantic~=2.6.3 pydantic~=2.6.3
requests