|
1 | 1 | import abc |
2 | 2 | import inspect |
3 | | -from typing import Union, Type, Dict |
| 3 | +from typing import Union, Type, Dict, Callable, Any |
4 | 4 |
|
5 | 5 | from pydantic import create_model, BaseModel |
6 | 6 |
|
@@ -78,3 +78,95 @@ class Model(Action): |
78 | 78 |
|
79 | 79 | return Model |
80 | 80 |
|
| 81 | + |
| 82 | +SimpleTextHandlerResult = tuple[bool, Union[str, tuple[str, str, ...]]] |
| 83 | +SimpleTextHandler = Callable[[str], SimpleTextHandlerResult] |
| 84 | + |
| 85 | + |
| 86 | +def capabilities_to_simple_text_handler(capabilities: Dict[str, Capability], default_capability: Capability = None, include_description: bool = True) -> tuple[Dict[str, str], SimpleTextHandler]: |
| 87 | + """ |
| 88 | + This function generates a simple text handler from a set of capabilities. |
| 89 | + It is to be used when no function calling is available, and structured output is not to be trusted, which is why it |
| 90 | + only supports the most basic of parameter types for the capabilities (str, int, float, bool). |
| 91 | +
|
| 92 | + As result it returns a dictionary of capability names to their descriptions and a parser function that can be used |
| 93 | + to parse the text input and execute it. The first return value of the parser function is a boolean indicating |
| 94 | + whether the parsing was successful, the second return value is a tuple containing the capability name, the parameters |
| 95 | + as a string and the result of the capability execution. |
| 96 | + """ |
| 97 | + def get_simple_fields(func, name) -> Dict[str, Type]: |
| 98 | + sig = inspect.signature(func) |
| 99 | + fields = {param: param_info.annotation for param, param_info in sig.parameters.items()} |
| 100 | + for param, param_type in fields.items(): |
| 101 | + if param_type not in (str, int, float, bool): |
| 102 | + raise ValueError(f"The command {name} is not compatible with this calling convention (this is not a LLM error, but rather a problem with the capability itself, the parameter {param} is {param_type} and not a simple type (str, int, float, bool))") |
| 103 | + return fields |
| 104 | + |
| 105 | + def parse_params(fields, params) -> tuple[bool, Union[str, Dict[str, Any]]]: |
| 106 | + split_params = params.split(" ", maxsplit=len(fields) - 1) |
| 107 | + if len(split_params) != len(fields): |
| 108 | + return False, "Invalid number of parameters" |
| 109 | + |
| 110 | + parsed_params = dict() |
| 111 | + for param, param_type in fields.items(): |
| 112 | + try: |
| 113 | + parsed_params[param] = param_type(split_params.pop(0)) |
| 114 | + except ValueError as e: |
| 115 | + return False, f"Could not parse parameter {param}: {e}" |
| 116 | + return True, parsed_params |
| 117 | + |
| 118 | + capability_descriptions = dict() |
| 119 | + capability_params = dict() |
| 120 | + for capability_name, capability in capabilities.items(): |
| 121 | + fields = get_simple_fields(capability.__call__, capability_name) |
| 122 | + |
| 123 | + description = f"`{capability_name}" |
| 124 | + if len(fields) > 0: |
| 125 | + description += " " + " ".join(param for param in fields) |
| 126 | + description += "`" |
| 127 | + if include_description: |
| 128 | + description += f": {capability.describe()}" |
| 129 | + |
| 130 | + capability_descriptions[capability_name] = description |
| 131 | + capability_params[capability_name] = fields |
| 132 | + |
| 133 | + def parser(text: str) -> SimpleTextHandlerResult: |
| 134 | + capability_name_and_params = text.split(" ", maxsplit=1) |
| 135 | + if len(capability_name_and_params) == 1: |
| 136 | + capability_name = capability_name_and_params[0] |
| 137 | + params = "" |
| 138 | + else: |
| 139 | + capability_name, params = capability_name_and_params |
| 140 | + if capability_name not in capabilities: |
| 141 | + return False, "Unknown command" |
| 142 | + |
| 143 | + success, parsing_result = parse_params(capability_params[capability_name], params) |
| 144 | + if not success: |
| 145 | + return False, parsing_result |
| 146 | + |
| 147 | + return True, (capability_name, params, capabilities[capability_name](**parsing_result)) |
| 148 | + |
| 149 | + resolved_parser: SimpleTextHandler = parser |
| 150 | + |
| 151 | + if default_capability is not None: |
| 152 | + default_fields = get_simple_fields(default_capability.__call__, "__default__") |
| 153 | + |
| 154 | + def default_capability_parser(text: str) -> SimpleTextHandlerResult: |
| 155 | + success, *output = parser(text) |
| 156 | + if success: |
| 157 | + return success, *output |
| 158 | + |
| 159 | + params = text |
| 160 | + success, parsing_result = parse_params(default_fields, params) |
| 161 | + if not success: |
| 162 | + params = text.split(" ", maxsplit=1)[1] |
| 163 | + success, parsing_result = parse_params(default_fields, params) |
| 164 | + if not success: |
| 165 | + return False, parsing_result |
| 166 | + |
| 167 | + return True, (capability_name, params, default_capability(**parsing_result)) |
| 168 | + |
| 169 | + |
| 170 | + resolved_parser = default_capability_parser |
| 171 | + |
| 172 | + return capability_descriptions, resolved_parser |
0 commit comments