mirror of https://github.com/ggerganov/llama.cpp.git
pydantic : replace uses of __annotations__ with get_type_hints (#8474)
* pydantic : replace uses of __annotations__ with get_type_hints

* pydantic : fix Python 3.9 and 3.10 support
This commit is contained in:
parent aaab2419ea
commit 090fca7a07
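As background for the change: __annotations__ exposes the raw annotations, which may be plain strings (for example under PEP 563 / from __future__ import annotations), while typing.get_type_hints() evaluates them into real type objects that the isclass()/issubclass() checks below can work with. A minimal sketch, with a made-up ExampleConfig class:

from __future__ import annotations  # PEP 563: annotations are stored as strings

from typing import get_type_hints


class ExampleConfig:
    name: str
    count: int


# The raw annotations are strings here, useless for issubclass()/get_origin() checks.
print(ExampleConfig.__annotations__)   # {'name': 'str', 'count': 'int'}

# get_type_hints() evaluates them back into actual type objects.
print(get_type_hints(ExampleConfig))   # {'name': <class 'str'>, 'count': <class 'int'>}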
@@ -6,7 +6,7 @@ import re
 from copy import copy
 from enum import Enum
 from inspect import getdoc, isclass
-from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin, get_type_hints

 from docstring_parser import parse
 from pydantic import BaseModel, create_model
@@ -53,35 +53,38 @@ class PydanticDataType(Enum):


 def map_pydantic_type_to_gbnf(pydantic_type: type[Any]) -> str:
-    if isclass(pydantic_type) and issubclass(pydantic_type, str):
+    origin_type = get_origin(pydantic_type)
+    origin_type = pydantic_type if origin_type is None else origin_type
+
+    if isclass(origin_type) and issubclass(origin_type, str):
         return PydanticDataType.STRING.value
-    elif isclass(pydantic_type) and issubclass(pydantic_type, bool):
+    elif isclass(origin_type) and issubclass(origin_type, bool):
         return PydanticDataType.BOOLEAN.value
-    elif isclass(pydantic_type) and issubclass(pydantic_type, int):
+    elif isclass(origin_type) and issubclass(origin_type, int):
         return PydanticDataType.INTEGER.value
-    elif isclass(pydantic_type) and issubclass(pydantic_type, float):
+    elif isclass(origin_type) and issubclass(origin_type, float):
         return PydanticDataType.FLOAT.value
-    elif isclass(pydantic_type) and issubclass(pydantic_type, Enum):
+    elif isclass(origin_type) and issubclass(origin_type, Enum):
         return PydanticDataType.ENUM.value

-    elif isclass(pydantic_type) and issubclass(pydantic_type, BaseModel):
-        return format_model_and_field_name(pydantic_type.__name__)
-    elif get_origin(pydantic_type) is list:
+    elif isclass(origin_type) and issubclass(origin_type, BaseModel):
+        return format_model_and_field_name(origin_type.__name__)
+    elif origin_type is list:
         element_type = get_args(pydantic_type)[0]
         return f"{map_pydantic_type_to_gbnf(element_type)}-list"
-    elif get_origin(pydantic_type) is set:
+    elif origin_type is set:
         element_type = get_args(pydantic_type)[0]
         return f"{map_pydantic_type_to_gbnf(element_type)}-set"
-    elif get_origin(pydantic_type) is Union:
+    elif origin_type is Union:
         union_types = get_args(pydantic_type)
         union_rules = [map_pydantic_type_to_gbnf(ut) for ut in union_types]
         return f"union-{'-or-'.join(union_rules)}"
-    elif get_origin(pydantic_type) is Optional:
+    elif origin_type is Optional:
         element_type = get_args(pydantic_type)[0]
         return f"optional-{map_pydantic_type_to_gbnf(element_type)}"
-    elif isclass(pydantic_type):
-        return f"{PydanticDataType.CUSTOM_CLASS.value}-{format_model_and_field_name(pydantic_type.__name__)}"
-    elif get_origin(pydantic_type) is dict:
+    elif isclass(origin_type):
+        return f"{PydanticDataType.CUSTOM_CLASS.value}-{format_model_and_field_name(origin_type.__name__)}"
+    elif origin_type is dict:
         key_type, value_type = get_args(pydantic_type)
         return f"custom-dict-key-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(key_type))}-value-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(value_type))}"
     else:
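The new origin_type fallback relies on how typing.get_origin()/get_args() behave; a quick sketch of that behavior (assuming Python 3.9 or newer, not part of the diff):

from typing import Optional, Union, get_args, get_origin

# Parameterized generics expose their container type and arguments...
assert get_origin(list[int]) is list
assert get_args(list[int]) == (int,)

# ...while plain classes have no origin, hence the fallback to the type itself.
assert get_origin(str) is None

# Optional[X] is represented as Union[X, None], so its origin is Union.
assert get_origin(Optional[int]) is Union
assert get_args(Optional[int]) == (int, type(None))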
@@ -118,7 +121,7 @@ def get_members_structure(cls, rule_name):
         # Modify this comprehension
         members = [
             f' "\\"{name}\\"" ":" {map_pydantic_type_to_gbnf(param_type)}'
-            for name, param_type in cls.__annotations__.items()
+            for name, param_type in get_type_hints(cls).items()
             if name != "self"
         ]

@@ -297,17 +300,20 @@ def generate_gbnf_rule_for_type(
     field_name = format_model_and_field_name(field_name)
     gbnf_type = map_pydantic_type_to_gbnf(field_type)

-    if isclass(field_type) and issubclass(field_type, BaseModel):
+    origin_type = get_origin(field_type)
+    origin_type = field_type if origin_type is None else origin_type
+
+    if isclass(origin_type) and issubclass(origin_type, BaseModel):
         nested_model_name = format_model_and_field_name(field_type.__name__)
         nested_model_rules, _ = generate_gbnf_grammar(field_type, processed_models, created_rules)
         rules.extend(nested_model_rules)
         gbnf_type, rules = nested_model_name, rules
-    elif isclass(field_type) and issubclass(field_type, Enum):
+    elif isclass(origin_type) and issubclass(origin_type, Enum):
         enum_values = [f'"\\"{e.value}\\""' for e in field_type]  # Adding escaped quotes
         enum_rule = f"{model_name}-{field_name} ::= {' | '.join(enum_values)}"
         rules.append(enum_rule)
         gbnf_type, rules = model_name + "-" + field_name, rules
-    elif get_origin(field_type) == list:  # Array
+    elif origin_type is list:  # Array
         element_type = get_args(field_type)[0]
         element_rule_name, additional_rules = generate_gbnf_rule_for_type(
             model_name, f"{field_name}-element", element_type, is_optional, processed_models, created_rules
@@ -317,7 +323,7 @@ def generate_gbnf_rule_for_type(
         rules.append(array_rule)
         gbnf_type, rules = model_name + "-" + field_name, rules

-    elif get_origin(field_type) == set or field_type == set:  # Array
+    elif origin_type is set:  # Array
         element_type = get_args(field_type)[0]
         element_rule_name, additional_rules = generate_gbnf_rule_for_type(
             model_name, f"{field_name}-element", element_type, is_optional, processed_models, created_rules
@@ -371,7 +377,7 @@ def generate_gbnf_rule_for_type(
             gbnf_type = f"{model_name}-{field_name}-optional"
         else:
             gbnf_type = f"{model_name}-{field_name}-union"
-    elif isclass(field_type) and issubclass(field_type, str):
+    elif isclass(origin_type) and issubclass(origin_type, str):
         if field_info and hasattr(field_info, "json_schema_extra") and field_info.json_schema_extra is not None:
             triple_quoted_string = field_info.json_schema_extra.get("triple_quoted_string", False)
             markdown_string = field_info.json_schema_extra.get("markdown_code_block", False)
@@ -387,8 +393,8 @@ def generate_gbnf_rule_for_type(
             gbnf_type = PydanticDataType.STRING.value

    elif (
-        isclass(field_type)
-        and issubclass(field_type, float)
+        isclass(origin_type)
+        and issubclass(origin_type, float)
         and field_info
         and hasattr(field_info, "json_schema_extra")
         and field_info.json_schema_extra is not None
@@ -413,8 +419,8 @@ def generate_gbnf_rule_for_type(
         )

    elif (
-        isclass(field_type)
-        and issubclass(field_type, int)
+        isclass(origin_type)
+        and issubclass(origin_type, int)
         and field_info
         and hasattr(field_info, "json_schema_extra")
         and field_info.json_schema_extra is not None
@@ -462,7 +468,7 @@ def generate_gbnf_grammar(model: type[BaseModel], processed_models: set[type[Bas
     if not issubclass(model, BaseModel):
         # For non-Pydantic classes, generate model_fields from __annotations__ or __init__
         if hasattr(model, "__annotations__") and model.__annotations__:
-            model_fields = {name: (typ, ...) for name, typ in model.__annotations__.items()}  # pyright: ignore[reportGeneralTypeIssues]
+            model_fields = {name: (typ, ...) for name, typ in get_type_hints(model).items()}
         else:
             init_signature = inspect.signature(model.__init__)
             parameters = init_signature.parameters
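A related property of get_type_hints() that a plain __annotations__ lookup lacks: it walks the MRO, so annotations inherited from base classes are included as well. A small sketch with made-up classes:

from typing import get_type_hints


class Animal:
    name: str


class Dog(Animal):
    breed: str


# The subclass's own __annotations__ only holds what it declares directly...
print(Dog.__annotations__)    # {'breed': <class 'str'>}

# ...while get_type_hints() also resolves the inherited 'name' field.
print(get_type_hints(Dog))    # {'name': <class 'str'>, 'breed': <class 'str'>}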
@@ -470,7 +476,7 @@ def generate_gbnf_grammar(model: type[BaseModel], processed_models: set[type[Bas
                             name != "self"}
     else:
         # For Pydantic models, use model_fields and check for ellipsis (required fields)
-        model_fields = model.__annotations__
+        model_fields = get_type_hints(model)

     model_rule_parts = []
     nested_rules = []
@@ -706,7 +712,7 @@ def generate_markdown_documentation(
         else:
             documentation += f"  Fields:\n"  # noqa: F541
         if isclass(model) and issubclass(model, BaseModel):
-            for name, field_type in model.__annotations__.items():
+            for name, field_type in get_type_hints(model).items():
                 # if name == "markdown_code_block":
                 #    continue
                 if get_origin(field_type) == list:
@@ -754,14 +760,17 @@ def generate_field_markdown(
     field_info = model.model_fields.get(field_name)
     field_description = field_info.description if field_info and field_info.description else ""

-    if get_origin(field_type) == list:
+    origin_type = get_origin(field_type)
+    origin_type = field_type if origin_type is None else origin_type
+
+    if origin_type == list:
         element_type = get_args(field_type)[0]
         field_text = f"{indent}{field_name} ({format_model_and_field_name(field_type.__name__)} of {format_model_and_field_name(element_type.__name__)})"
         if field_description != "":
             field_text += ":\n"
         else:
             field_text += "\n"
-    elif get_origin(field_type) == Union:
+    elif origin_type == Union:
         element_types = get_args(field_type)
         types = []
         for element_type in element_types:
@@ -792,9 +801,9 @@ def generate_field_markdown(
         example_text = f"'{field_example}'" if isinstance(field_example, str) else field_example
         field_text += f"{indent}  Example: {example_text}\n"

-    if isclass(field_type) and issubclass(field_type, BaseModel):
+    if isclass(origin_type) and issubclass(origin_type, BaseModel):
         field_text += f"{indent}  Details:\n"
-        for name, type_ in field_type.__annotations__.items():
+        for name, type_ in get_type_hints(field_type).items():
             field_text += generate_field_markdown(name, type_, field_type, depth + 2)

     return field_text
@@ -855,7 +864,7 @@ def generate_text_documentation(

         if isclass(model) and issubclass(model, BaseModel):
             documentation_fields = ""
-            for name, field_type in model.__annotations__.items():
+            for name, field_type in get_type_hints(model).items():
                 # if name == "markdown_code_block":
                 #    continue
                 if get_origin(field_type) == list:
@@ -948,7 +957,7 @@ def generate_field_text(

     if isclass(field_type) and issubclass(field_type, BaseModel):
         field_text += f"{indent}  Details:\n"
-        for name, type_ in field_type.__annotations__.items():
+        for name, type_ in get_type_hints(field_type).items():
             field_text += generate_field_text(name, type_, field_type, depth + 2)

     return field_text

@@ -20,6 +20,8 @@ def create_completion(prompt, grammar):
     response = requests.post("http://127.0.0.1:8080/completion", headers=headers, json=data)
     data = response.json()

+    assert data.get("error") is None, data
+
     print(data["content"])
     return data["content"]

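The added assertion makes a failed completion fail loudly with the server's error payload instead of the later data["content"] lookup raising a bare KeyError. A sketch with a purely hypothetical error payload (not an actual server response):

# Hypothetical error payload, for illustration only.
data = {"error": {"message": "model not loaded"}}

try:
    # The assertion surfaces the whole payload in the exception message...
    assert data.get("error") is None, data
except AssertionError as exc:
    print(exc)  # {'error': {'message': 'model not loaded'}}

# ...whereas data["content"] on this payload would raise a bare KeyError.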
@@ -1,2 +1,3 @@
 docstring_parser~=0.15
 pydantic~=2.6.3
+requests