mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 03:14:35 +00:00
pydantic : replace uses of __annotations__ with get_type_hints
This commit is contained in:
parent
17eb6aa8a9
commit
eed299f0d2
@ -6,7 +6,7 @@ import re
|
||||
from copy import copy
|
||||
from enum import Enum
|
||||
from inspect import getdoc, isclass
|
||||
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin
|
||||
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin, get_type_hints
|
||||
|
||||
from docstring_parser import parse
|
||||
from pydantic import BaseModel, create_model
|
||||
@ -118,7 +118,7 @@ def get_members_structure(cls, rule_name):
|
||||
# Modify this comprehension
|
||||
members = [
|
||||
f' "\\"{name}\\"" ":" {map_pydantic_type_to_gbnf(param_type)}'
|
||||
for name, param_type in cls.__annotations__.items()
|
||||
for name, param_type in get_type_hints(cls).items()
|
||||
if name != "self"
|
||||
]
|
||||
|
||||
@ -462,7 +462,7 @@ def generate_gbnf_grammar(model: type[BaseModel], processed_models: set[type[Bas
|
||||
if not issubclass(model, BaseModel):
|
||||
# For non-Pydantic classes, generate model_fields from __annotations__ or __init__
|
||||
if hasattr(model, "__annotations__") and model.__annotations__:
|
||||
model_fields = {name: (typ, ...) for name, typ in model.__annotations__.items()} # pyright: ignore[reportGeneralTypeIssues]
|
||||
model_fields = {name: (typ, ...) for name, typ in get_type_hints(model).items()}
|
||||
else:
|
||||
init_signature = inspect.signature(model.__init__)
|
||||
parameters = init_signature.parameters
|
||||
@ -470,7 +470,7 @@ def generate_gbnf_grammar(model: type[BaseModel], processed_models: set[type[Bas
|
||||
name != "self"}
|
||||
else:
|
||||
# For Pydantic models, use model_fields and check for ellipsis (required fields)
|
||||
model_fields = model.__annotations__
|
||||
model_fields = get_type_hints(model)
|
||||
|
||||
model_rule_parts = []
|
||||
nested_rules = []
|
||||
@ -706,7 +706,7 @@ def generate_markdown_documentation(
|
||||
else:
|
||||
documentation += f" Fields:\n" # noqa: F541
|
||||
if isclass(model) and issubclass(model, BaseModel):
|
||||
for name, field_type in model.__annotations__.items():
|
||||
for name, field_type in get_type_hints(model).items():
|
||||
# if name == "markdown_code_block":
|
||||
# continue
|
||||
if get_origin(field_type) == list:
|
||||
@ -794,7 +794,7 @@ def generate_field_markdown(
|
||||
|
||||
if isclass(field_type) and issubclass(field_type, BaseModel):
|
||||
field_text += f"{indent} Details:\n"
|
||||
for name, type_ in field_type.__annotations__.items():
|
||||
for name, type_ in get_type_hints(field_type).items():
|
||||
field_text += generate_field_markdown(name, type_, field_type, depth + 2)
|
||||
|
||||
return field_text
|
||||
@ -855,7 +855,7 @@ def generate_text_documentation(
|
||||
|
||||
if isclass(model) and issubclass(model, BaseModel):
|
||||
documentation_fields = ""
|
||||
for name, field_type in model.__annotations__.items():
|
||||
for name, field_type in get_type_hints(model).items():
|
||||
# if name == "markdown_code_block":
|
||||
# continue
|
||||
if get_origin(field_type) == list:
|
||||
@ -948,7 +948,7 @@ def generate_field_text(
|
||||
|
||||
if isclass(field_type) and issubclass(field_type, BaseModel):
|
||||
field_text += f"{indent} Details:\n"
|
||||
for name, type_ in field_type.__annotations__.items():
|
||||
for name, type_ in get_type_hints(field_type).items():
|
||||
field_text += generate_field_text(name, type_, field_type, depth + 2)
|
||||
|
||||
return field_text
|
||||
|
@ -20,6 +20,8 @@ def create_completion(prompt, grammar):
|
||||
response = requests.post("http://127.0.0.1:8080/completion", headers=headers, json=data)
|
||||
data = response.json()
|
||||
|
||||
assert data.get("error") is None, data
|
||||
|
||||
print(data["content"])
|
||||
return data["content"]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user