mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 12:10:18 +00:00
pydantic : replace uses of __annotations__ with get_type_hints
This commit is contained in:
parent
17eb6aa8a9
commit
eed299f0d2
@ -6,7 +6,7 @@ import re
|
|||||||
from copy import copy
|
from copy import copy
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from inspect import getdoc, isclass
|
from inspect import getdoc, isclass
|
||||||
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin
|
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin, get_type_hints
|
||||||
|
|
||||||
from docstring_parser import parse
|
from docstring_parser import parse
|
||||||
from pydantic import BaseModel, create_model
|
from pydantic import BaseModel, create_model
|
||||||
@ -118,7 +118,7 @@ def get_members_structure(cls, rule_name):
|
|||||||
# Modify this comprehension
|
# Modify this comprehension
|
||||||
members = [
|
members = [
|
||||||
f' "\\"{name}\\"" ":" {map_pydantic_type_to_gbnf(param_type)}'
|
f' "\\"{name}\\"" ":" {map_pydantic_type_to_gbnf(param_type)}'
|
||||||
for name, param_type in cls.__annotations__.items()
|
for name, param_type in get_type_hints(cls).items()
|
||||||
if name != "self"
|
if name != "self"
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -462,7 +462,7 @@ def generate_gbnf_grammar(model: type[BaseModel], processed_models: set[type[Bas
|
|||||||
if not issubclass(model, BaseModel):
|
if not issubclass(model, BaseModel):
|
||||||
# For non-Pydantic classes, generate model_fields from __annotations__ or __init__
|
# For non-Pydantic classes, generate model_fields from __annotations__ or __init__
|
||||||
if hasattr(model, "__annotations__") and model.__annotations__:
|
if hasattr(model, "__annotations__") and model.__annotations__:
|
||||||
model_fields = {name: (typ, ...) for name, typ in model.__annotations__.items()} # pyright: ignore[reportGeneralTypeIssues]
|
model_fields = {name: (typ, ...) for name, typ in get_type_hints(model).items()}
|
||||||
else:
|
else:
|
||||||
init_signature = inspect.signature(model.__init__)
|
init_signature = inspect.signature(model.__init__)
|
||||||
parameters = init_signature.parameters
|
parameters = init_signature.parameters
|
||||||
@ -470,7 +470,7 @@ def generate_gbnf_grammar(model: type[BaseModel], processed_models: set[type[Bas
|
|||||||
name != "self"}
|
name != "self"}
|
||||||
else:
|
else:
|
||||||
# For Pydantic models, use model_fields and check for ellipsis (required fields)
|
# For Pydantic models, use model_fields and check for ellipsis (required fields)
|
||||||
model_fields = model.__annotations__
|
model_fields = get_type_hints(model)
|
||||||
|
|
||||||
model_rule_parts = []
|
model_rule_parts = []
|
||||||
nested_rules = []
|
nested_rules = []
|
||||||
@ -706,7 +706,7 @@ def generate_markdown_documentation(
|
|||||||
else:
|
else:
|
||||||
documentation += f" Fields:\n" # noqa: F541
|
documentation += f" Fields:\n" # noqa: F541
|
||||||
if isclass(model) and issubclass(model, BaseModel):
|
if isclass(model) and issubclass(model, BaseModel):
|
||||||
for name, field_type in model.__annotations__.items():
|
for name, field_type in get_type_hints(model).items():
|
||||||
# if name == "markdown_code_block":
|
# if name == "markdown_code_block":
|
||||||
# continue
|
# continue
|
||||||
if get_origin(field_type) == list:
|
if get_origin(field_type) == list:
|
||||||
@ -794,7 +794,7 @@ def generate_field_markdown(
|
|||||||
|
|
||||||
if isclass(field_type) and issubclass(field_type, BaseModel):
|
if isclass(field_type) and issubclass(field_type, BaseModel):
|
||||||
field_text += f"{indent} Details:\n"
|
field_text += f"{indent} Details:\n"
|
||||||
for name, type_ in field_type.__annotations__.items():
|
for name, type_ in get_type_hints(field_type).items():
|
||||||
field_text += generate_field_markdown(name, type_, field_type, depth + 2)
|
field_text += generate_field_markdown(name, type_, field_type, depth + 2)
|
||||||
|
|
||||||
return field_text
|
return field_text
|
||||||
@ -855,7 +855,7 @@ def generate_text_documentation(
|
|||||||
|
|
||||||
if isclass(model) and issubclass(model, BaseModel):
|
if isclass(model) and issubclass(model, BaseModel):
|
||||||
documentation_fields = ""
|
documentation_fields = ""
|
||||||
for name, field_type in model.__annotations__.items():
|
for name, field_type in get_type_hints(model).items():
|
||||||
# if name == "markdown_code_block":
|
# if name == "markdown_code_block":
|
||||||
# continue
|
# continue
|
||||||
if get_origin(field_type) == list:
|
if get_origin(field_type) == list:
|
||||||
@ -948,7 +948,7 @@ def generate_field_text(
|
|||||||
|
|
||||||
if isclass(field_type) and issubclass(field_type, BaseModel):
|
if isclass(field_type) and issubclass(field_type, BaseModel):
|
||||||
field_text += f"{indent} Details:\n"
|
field_text += f"{indent} Details:\n"
|
||||||
for name, type_ in field_type.__annotations__.items():
|
for name, type_ in get_type_hints(field_type).items():
|
||||||
field_text += generate_field_text(name, type_, field_type, depth + 2)
|
field_text += generate_field_text(name, type_, field_type, depth + 2)
|
||||||
|
|
||||||
return field_text
|
return field_text
|
||||||
|
@ -20,6 +20,8 @@ def create_completion(prompt, grammar):
|
|||||||
response = requests.post("http://127.0.0.1:8080/completion", headers=headers, json=data)
|
response = requests.post("http://127.0.0.1:8080/completion", headers=headers, json=data)
|
||||||
data = response.json()
|
data = response.json()
|
||||||
|
|
||||||
|
assert data.get("error") is None, data
|
||||||
|
|
||||||
print(data["content"])
|
print(data["content"])
|
||||||
return data["content"]
|
return data["content"]
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user