From f89eaa921e481fe6aeb7ead98d41af516d46496d Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sat, 13 Jul 2024 21:52:45 -0400 Subject: [PATCH] pydantic : fix Python 3.9 and 3.10 support --- examples/pydantic_models_to_grammar.py | 61 +++++++++++++++----------- requirements/requirements-pydantic.txt | 1 + 2 files changed, 36 insertions(+), 26 deletions(-) diff --git a/examples/pydantic_models_to_grammar.py b/examples/pydantic_models_to_grammar.py index cb62fa705..93e5dcb6c 100644 --- a/examples/pydantic_models_to_grammar.py +++ b/examples/pydantic_models_to_grammar.py @@ -53,35 +53,38 @@ class PydanticDataType(Enum): def map_pydantic_type_to_gbnf(pydantic_type: type[Any]) -> str: - if isclass(pydantic_type) and issubclass(pydantic_type, str): + origin_type = get_origin(pydantic_type) + origin_type = pydantic_type if origin_type is None else origin_type + + if isclass(origin_type) and issubclass(origin_type, str): return PydanticDataType.STRING.value - elif isclass(pydantic_type) and issubclass(pydantic_type, bool): + elif isclass(origin_type) and issubclass(origin_type, bool): return PydanticDataType.BOOLEAN.value - elif isclass(pydantic_type) and issubclass(pydantic_type, int): + elif isclass(origin_type) and issubclass(origin_type, int): return PydanticDataType.INTEGER.value - elif isclass(pydantic_type) and issubclass(pydantic_type, float): + elif isclass(origin_type) and issubclass(origin_type, float): return PydanticDataType.FLOAT.value - elif isclass(pydantic_type) and issubclass(pydantic_type, Enum): + elif isclass(origin_type) and issubclass(origin_type, Enum): return PydanticDataType.ENUM.value - elif isclass(pydantic_type) and issubclass(pydantic_type, BaseModel): - return format_model_and_field_name(pydantic_type.__name__) - elif get_origin(pydantic_type) is list: + elif isclass(origin_type) and issubclass(origin_type, BaseModel): + return format_model_and_field_name(origin_type.__name__) + elif origin_type is list: element_type = get_args(pydantic_type)[0] return f"{map_pydantic_type_to_gbnf(element_type)}-list" - elif get_origin(pydantic_type) is set: + elif origin_type is set: element_type = get_args(pydantic_type)[0] return f"{map_pydantic_type_to_gbnf(element_type)}-set" - elif get_origin(pydantic_type) is Union: + elif origin_type is Union: union_types = get_args(pydantic_type) union_rules = [map_pydantic_type_to_gbnf(ut) for ut in union_types] return f"union-{'-or-'.join(union_rules)}" - elif get_origin(pydantic_type) is Optional: + elif origin_type is Optional: element_type = get_args(pydantic_type)[0] return f"optional-{map_pydantic_type_to_gbnf(element_type)}" - elif isclass(pydantic_type): - return f"{PydanticDataType.CUSTOM_CLASS.value}-{format_model_and_field_name(pydantic_type.__name__)}" - elif get_origin(pydantic_type) is dict: + elif isclass(origin_type): + return f"{PydanticDataType.CUSTOM_CLASS.value}-{format_model_and_field_name(origin_type.__name__)}" + elif origin_type is dict: key_type, value_type = get_args(pydantic_type) return f"custom-dict-key-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(key_type))}-value-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(value_type))}" else: @@ -297,17 +300,20 @@ def generate_gbnf_rule_for_type( field_name = format_model_and_field_name(field_name) gbnf_type = map_pydantic_type_to_gbnf(field_type) - if isclass(field_type) and issubclass(field_type, BaseModel): + origin_type = get_origin(field_type) + origin_type = field_type if origin_type is None else origin_type + + if isclass(origin_type) and issubclass(origin_type, BaseModel): nested_model_name = format_model_and_field_name(field_type.__name__) nested_model_rules, _ = generate_gbnf_grammar(field_type, processed_models, created_rules) rules.extend(nested_model_rules) gbnf_type, rules = nested_model_name, rules - elif isclass(field_type) and issubclass(field_type, Enum): + elif isclass(origin_type) and issubclass(origin_type, Enum): enum_values = [f'"\\"{e.value}\\""' for e in field_type] # Adding escaped quotes enum_rule = f"{model_name}-{field_name} ::= {' | '.join(enum_values)}" rules.append(enum_rule) gbnf_type, rules = model_name + "-" + field_name, rules - elif get_origin(field_type) == list: # Array + elif origin_type is list: # Array element_type = get_args(field_type)[0] element_rule_name, additional_rules = generate_gbnf_rule_for_type( model_name, f"{field_name}-element", element_type, is_optional, processed_models, created_rules @@ -317,7 +323,7 @@ def generate_gbnf_rule_for_type( rules.append(array_rule) gbnf_type, rules = model_name + "-" + field_name, rules - elif get_origin(field_type) == set or field_type == set: # Array + elif origin_type is set: # Array element_type = get_args(field_type)[0] element_rule_name, additional_rules = generate_gbnf_rule_for_type( model_name, f"{field_name}-element", element_type, is_optional, processed_models, created_rules @@ -371,7 +377,7 @@ def generate_gbnf_rule_for_type( gbnf_type = f"{model_name}-{field_name}-optional" else: gbnf_type = f"{model_name}-{field_name}-union" - elif isclass(field_type) and issubclass(field_type, str): + elif isclass(origin_type) and issubclass(origin_type, str): if field_info and hasattr(field_info, "json_schema_extra") and field_info.json_schema_extra is not None: triple_quoted_string = field_info.json_schema_extra.get("triple_quoted_string", False) markdown_string = field_info.json_schema_extra.get("markdown_code_block", False) @@ -387,8 +393,8 @@ def generate_gbnf_rule_for_type( gbnf_type = PydanticDataType.STRING.value elif ( - isclass(field_type) - and issubclass(field_type, float) + isclass(origin_type) + and issubclass(origin_type, float) and field_info and hasattr(field_info, "json_schema_extra") and field_info.json_schema_extra is not None @@ -413,8 +419,8 @@ def generate_gbnf_rule_for_type( ) elif ( - isclass(field_type) - and issubclass(field_type, int) + isclass(origin_type) + and issubclass(origin_type, int) and field_info and hasattr(field_info, "json_schema_extra") and field_info.json_schema_extra is not None @@ -754,14 +760,17 @@ def generate_field_markdown( field_info = model.model_fields.get(field_name) field_description = field_info.description if field_info and field_info.description else "" - if get_origin(field_type) == list: + origin_type = get_origin(field_type) + origin_type = field_type if origin_type is None else origin_type + + if origin_type == list: element_type = get_args(field_type)[0] field_text = f"{indent}{field_name} ({format_model_and_field_name(field_type.__name__)} of {format_model_and_field_name(element_type.__name__)})" if field_description != "": field_text += ":\n" else: field_text += "\n" - elif get_origin(field_type) == Union: + elif origin_type == Union: element_types = get_args(field_type) types = [] for element_type in element_types: @@ -792,7 +801,7 @@ def generate_field_markdown( example_text = f"'{field_example}'" if isinstance(field_example, str) else field_example field_text += f"{indent} Example: {example_text}\n" - if isclass(field_type) and issubclass(field_type, BaseModel): + if isclass(origin_type) and issubclass(origin_type, BaseModel): field_text += f"{indent} Details:\n" for name, type_ in get_type_hints(field_type).items(): field_text += generate_field_markdown(name, type_, field_type, depth + 2) diff --git a/requirements/requirements-pydantic.txt b/requirements/requirements-pydantic.txt index 2f9455b14..bdd423e07 100644 --- a/requirements/requirements-pydantic.txt +++ b/requirements/requirements-pydantic.txt @@ -1,2 +1,3 @@ docstring_parser~=0.15 pydantic~=2.6.3 +requests