pydantic : replace uses of __annotations__ with get_type_hints

This commit is contained in:
Francis Couture-Harpin 2024-07-13 16:46:26 -04:00
parent 17eb6aa8a9
commit eed299f0d2
2 changed files with 10 additions and 8 deletions

View File

@ -6,7 +6,7 @@ import re
from copy import copy from copy import copy
from enum import Enum from enum import Enum
from inspect import getdoc, isclass from inspect import getdoc, isclass
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin, get_type_hints
from docstring_parser import parse from docstring_parser import parse
from pydantic import BaseModel, create_model from pydantic import BaseModel, create_model
@ -118,7 +118,7 @@ def get_members_structure(cls, rule_name):
# Modify this comprehension # Modify this comprehension
members = [ members = [
f' "\\"{name}\\"" ":" {map_pydantic_type_to_gbnf(param_type)}' f' "\\"{name}\\"" ":" {map_pydantic_type_to_gbnf(param_type)}'
for name, param_type in cls.__annotations__.items() for name, param_type in get_type_hints(cls).items()
if name != "self" if name != "self"
] ]
@ -462,7 +462,7 @@ def generate_gbnf_grammar(model: type[BaseModel], processed_models: set[type[Bas
if not issubclass(model, BaseModel): if not issubclass(model, BaseModel):
# For non-Pydantic classes, generate model_fields from __annotations__ or __init__ # For non-Pydantic classes, generate model_fields from __annotations__ or __init__
if hasattr(model, "__annotations__") and model.__annotations__: if hasattr(model, "__annotations__") and model.__annotations__:
model_fields = {name: (typ, ...) for name, typ in model.__annotations__.items()} # pyright: ignore[reportGeneralTypeIssues] model_fields = {name: (typ, ...) for name, typ in get_type_hints(model).items()}
else: else:
init_signature = inspect.signature(model.__init__) init_signature = inspect.signature(model.__init__)
parameters = init_signature.parameters parameters = init_signature.parameters
@ -470,7 +470,7 @@ def generate_gbnf_grammar(model: type[BaseModel], processed_models: set[type[Bas
name != "self"} name != "self"}
else: else:
# For Pydantic models, use model_fields and check for ellipsis (required fields) # For Pydantic models, use model_fields and check for ellipsis (required fields)
model_fields = model.__annotations__ model_fields = get_type_hints(model)
model_rule_parts = [] model_rule_parts = []
nested_rules = [] nested_rules = []
@ -706,7 +706,7 @@ def generate_markdown_documentation(
else: else:
documentation += f" Fields:\n" # noqa: F541 documentation += f" Fields:\n" # noqa: F541
if isclass(model) and issubclass(model, BaseModel): if isclass(model) and issubclass(model, BaseModel):
for name, field_type in model.__annotations__.items(): for name, field_type in get_type_hints(model).items():
# if name == "markdown_code_block": # if name == "markdown_code_block":
# continue # continue
if get_origin(field_type) == list: if get_origin(field_type) == list:
@ -794,7 +794,7 @@ def generate_field_markdown(
if isclass(field_type) and issubclass(field_type, BaseModel): if isclass(field_type) and issubclass(field_type, BaseModel):
field_text += f"{indent} Details:\n" field_text += f"{indent} Details:\n"
for name, type_ in field_type.__annotations__.items(): for name, type_ in get_type_hints(field_type).items():
field_text += generate_field_markdown(name, type_, field_type, depth + 2) field_text += generate_field_markdown(name, type_, field_type, depth + 2)
return field_text return field_text
@ -855,7 +855,7 @@ def generate_text_documentation(
if isclass(model) and issubclass(model, BaseModel): if isclass(model) and issubclass(model, BaseModel):
documentation_fields = "" documentation_fields = ""
for name, field_type in model.__annotations__.items(): for name, field_type in get_type_hints(model).items():
# if name == "markdown_code_block": # if name == "markdown_code_block":
# continue # continue
if get_origin(field_type) == list: if get_origin(field_type) == list:
@ -948,7 +948,7 @@ def generate_field_text(
if isclass(field_type) and issubclass(field_type, BaseModel): if isclass(field_type) and issubclass(field_type, BaseModel):
field_text += f"{indent} Details:\n" field_text += f"{indent} Details:\n"
for name, type_ in field_type.__annotations__.items(): for name, type_ in get_type_hints(field_type).items():
field_text += generate_field_text(name, type_, field_type, depth + 2) field_text += generate_field_text(name, type_, field_type, depth + 2)
return field_text return field_text

View File

@ -20,6 +20,8 @@ def create_completion(prompt, grammar):
response = requests.post("http://127.0.0.1:8080/completion", headers=headers, json=data) response = requests.post("http://127.0.0.1:8080/completion", headers=headers, json=data)
data = response.json() data = response.json()
assert data.get("error") is None, data
print(data["content"]) print(data["content"])
return data["content"] return data["content"]