Skip to content

Commit a92428a

Browse files
authored
Merge pull request #61 from ipa-lab/fixes
Fixes
2 parents 8c32520 + d1f4ab8 commit a92428a

11 files changed

Lines changed: 132 additions & 26 deletions

File tree

capabilities/capability.py

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import abc
22
import inspect
3-
from typing import Union, Type, Dict
3+
from typing import Union, Type, Dict, Callable, Any
44

55
from pydantic import create_model, BaseModel
66

@@ -78,3 +78,95 @@ class Model(Action):
7878

7979
return Model
8080

81+
82+
SimpleTextHandlerResult = tuple[bool, Union[str, tuple[str, str, ...]]]
83+
SimpleTextHandler = Callable[[str], SimpleTextHandlerResult]
84+
85+
86+
def capabilities_to_simple_text_handler(capabilities: Dict[str, Capability], default_capability: Capability = None, include_description: bool = True) -> tuple[Dict[str, str], SimpleTextHandler]:
87+
"""
88+
This function generates a simple text handler from a set of capabilities.
89+
It is to be used when no function calling is available, and structured output is not to be trusted, which is why it
90+
only supports the most basic of parameter types for the capabilities (str, int, float, bool).
91+
92+
As result it returns a dictionary of capability names to their descriptions and a parser function that can be used
93+
to parse the text input and execute it. The first return value of the parser function is a boolean indicating
94+
whether the parsing was successful, the second return value is a tuple containing the capability name, the parameters
95+
as a string and the result of the capability execution.
96+
"""
97+
def get_simple_fields(func, name) -> Dict[str, Type]:
98+
sig = inspect.signature(func)
99+
fields = {param: param_info.annotation for param, param_info in sig.parameters.items()}
100+
for param, param_type in fields.items():
101+
if param_type not in (str, int, float, bool):
102+
raise ValueError(f"The command {name} is not compatible with this calling convention (this is not a LLM error, but rather a problem with the capability itself, the parameter {param} is {param_type} and not a simple type (str, int, float, bool))")
103+
return fields
104+
105+
def parse_params(fields, params) -> tuple[bool, Union[str, Dict[str, Any]]]:
106+
split_params = params.split(" ", maxsplit=len(fields) - 1)
107+
if len(split_params) != len(fields):
108+
return False, "Invalid number of parameters"
109+
110+
parsed_params = dict()
111+
for param, param_type in fields.items():
112+
try:
113+
parsed_params[param] = param_type(split_params.pop(0))
114+
except ValueError as e:
115+
return False, f"Could not parse parameter {param}: {e}"
116+
return True, parsed_params
117+
118+
capability_descriptions = dict()
119+
capability_params = dict()
120+
for capability_name, capability in capabilities.items():
121+
fields = get_simple_fields(capability.__call__, capability_name)
122+
123+
description = f"`{capability_name}"
124+
if len(fields) > 0:
125+
description += " " + " ".join(param for param in fields)
126+
description += "`"
127+
if include_description:
128+
description += f": {capability.describe()}"
129+
130+
capability_descriptions[capability_name] = description
131+
capability_params[capability_name] = fields
132+
133+
def parser(text: str) -> SimpleTextHandlerResult:
134+
capability_name_and_params = text.split(" ", maxsplit=1)
135+
if len(capability_name_and_params) == 1:
136+
capability_name = capability_name_and_params[0]
137+
params = ""
138+
else:
139+
capability_name, params = capability_name_and_params
140+
if capability_name not in capabilities:
141+
return False, "Unknown command"
142+
143+
success, parsing_result = parse_params(capability_params[capability_name], params)
144+
if not success:
145+
return False, parsing_result
146+
147+
return True, (capability_name, params, capabilities[capability_name](**parsing_result))
148+
149+
resolved_parser: SimpleTextHandler = parser
150+
151+
if default_capability is not None:
152+
default_fields = get_simple_fields(default_capability.__call__, "__default__")
153+
154+
def default_capability_parser(text: str) -> SimpleTextHandlerResult:
155+
success, *output = parser(text)
156+
if success:
157+
return success, *output
158+
159+
params = text
160+
success, parsing_result = parse_params(default_fields, params)
161+
if not success:
162+
params = text.split(" ", maxsplit=1)[1]
163+
success, parsing_result = parse_params(default_fields, params)
164+
if not success:
165+
return False, parsing_result
166+
167+
return True, (capability_name, params, default_capability(**parsing_result))
168+
169+
170+
resolved_parser = default_capability_parser
171+
172+
return capability_descriptions, resolved_parser

capabilities/psexec_test_credential.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class PSExecTestCredential(Capability):
1111
conn: PSExecConnection
1212

1313
def describe(self) -> str:
14-
return f"give credentials to be tested by stating `{self.get_name()} username password`"
14+
return f"give credentials to be tested"
1515

1616
def get_name(self) -> str:
1717
return "test_credential"

capabilities/ssh_run_command.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,21 @@
1818
@dataclass
1919
class SSHRunCommand(Capability):
2020
conn: SSHConnection
21+
timeout: int = 10
2122

2223
def describe(self) -> str:
23-
return f"give a command to be executed on the shell and I will respond with the terminal output when running this command on the linux server. The given command must not require user interaction. Only state the to be executed command. The command should be used for enumeration or privilege escalation."
24+
return f"give a command to be executed and I will respond with the terminal output when running this command over SSH on the linux machine. The given command must not require user interaction."
2425

25-
def __call__(self, command: str, timeout:int=10) -> Tuple[str, bool]:
26+
def get_name(self):
27+
return "exec_command"
28+
29+
def __call__(self, command: str) -> Tuple[str, bool]:
2630
got_root = False
31+
32+
if command.startswith(self.get_name()):
33+
cmd_parts = command.split(" ", 1)
34+
command = cmd_parts[1]
35+
2736
sudo_pass = Responder(
2837
pattern=r'\[sudo\] password for ' + self.conn.username + ':',
2938
response=self.conn.password + '\n',
@@ -32,7 +41,7 @@ def __call__(self, command: str, timeout:int=10) -> Tuple[str, bool]:
3241
out = StringIO()
3342

3443
try:
35-
resp = self.conn.run(command, pty=True, warn=True, out_stream=out, watchers=[sudo_pass], timeout=timeout)
44+
resp = self.conn.run(command, pty=True, warn=True, out_stream=out, watchers=[sudo_pass], timeout=self.timeout)
3645
except Exception as e:
3746
print("TIMEOUT! Could we have become root?")
3847
out.seek(0)

capabilities/ssh_test_credential.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,13 @@ class SSHTestCredential(Capability):
1212
conn: SSHConnection
1313

1414
def describe(self) -> str:
15-
return f"give credentials to be tested by stating `{self.get_name()} username password`"
15+
return f"give credentials to be tested"
1616

1717
def get_name(self):
1818
return "test_credential"
1919

20-
def __call__(self, command: str) -> Tuple[str, bool]:
21-
cmd_parts = command.split(" ")
22-
assert (cmd_parts[0] == "test_credential")
23-
24-
if len(cmd_parts) != 3:
25-
return "didn't provide username/password", False
26-
27-
test_conn = self.conn.new_with(username=cmd_parts[1], password=cmd_parts[2])
20+
def __call__(self, username: str, password: str) -> Tuple[str, bool]:
21+
test_conn = self.conn.new_with(username=username, password=password)
2822
try:
2923
test_conn.init()
3024
user = test_conn.run("whoami")[0].strip('\n\r ')

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,4 @@ tiktoken==0.6.0
2727
urllib3==2.2.1
2828
wrapt==1.16.0
2929
instructor==1.2.2
30+
PyYAML==6.0.1

usecases/agents.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,27 @@
1+
from abc import ABC
12
from dataclasses import dataclass, field
23
from typing import Dict
34

4-
from capabilities.capability import Capability
5+
from capabilities.capability import Capability, capabilities_to_simple_text_handler
56
from usecases.common_patterns import RoundBasedUseCase
67

78

89
@dataclass
9-
class Agent(RoundBasedUseCase):
10-
10+
class Agent(RoundBasedUseCase, ABC):
1111
_capabilities: Dict[str, Capability] = field(default_factory=dict)
1212
_default_capability: Capability = None
1313

1414
def init(self):
1515
super().init()
1616

17-
def add_capability(self, cap:Capability, default:bool=False):
17+
def add_capability(self, cap: Capability, default: bool = False):
1818
self._capabilities[cap.get_name()] = cap
1919
if default:
2020
self._default_capability = cap
2121

22-
def get_capability(self, name:str) -> Capability:
22+
def get_capability(self, name: str) -> Capability:
2323
return self._capabilities.get(name, self._default_capability)
2424

2525
def get_capability_block(self) -> str:
26-
return "You can either\n\n" + "\n".join(map(lambda i: f"- {i.describe()}", self._capabilities.values()))
26+
capability_descriptions, _parser = capabilities_to_simple_text_handler(self._capabilities)
27+
return "You can either\n\n" + "\n".join(f"- {description}" for description in capability_descriptions.values())

usecases/privesc/common.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from rich.panel import Panel
77

88
from capabilities import Capability
9+
from capabilities.capability import capabilities_to_simple_text_handler
910
from usecases.agents import Agent
1011
from utils import llm_util, ui
1112
from utils.cli_history import SlidingCliHistory
@@ -48,12 +49,20 @@ def perform_round(self, turn):
4849

4950
with self.console.status("[bold green]Executing that command..."):
5051
self.console.print(Panel(answer.result, title="[bold cyan]Got command from LLM:"))
51-
result, got_root = self.get_capability(cmd.split(" ", 1)[0])(cmd)
52+
_capability_descriptions, parser = capabilities_to_simple_text_handler(self._capabilities, default_capability=self._default_capability)
53+
success, *output = parser(cmd)
54+
if not success:
55+
self.console.print(Panel(output[0], title=f"[bold red]Error parsing command:"))
56+
return False
57+
58+
assert(len(output) == 1)
59+
capability, cmd, (result, got_root) = output[0]
5260

5361
# log and output the command and its result
5462
self.log_db.add_log_query(self._run_id, turn, cmd, result, answer)
5563
if self._sliding_history:
5664
self._sliding_history.add_command(cmd, result)
65+
5766
self.console.print(Panel(result, title=f"[bold cyan]{cmd}"))
5867

5968
# analyze the result..

usecases/privesc/linux.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def read_hint(self):
104104

105105
run_cmd = "wget -q 'https://github.com/diego-treitos/linux-smart-enumeration/releases/latest/download/lse.sh' -O lse.sh;chmod 700 lse.sh; ./lse.sh -c -i -l 0 | grep -v 'nope$' | grep -v 'skip$'"
106106

107-
result, got_root = SSHRunCommand(conn=self.conn)(run_cmd, timeout=120)
107+
result, got_root = SSHRunCommand(conn=self.conn, timeout=120)(run_cmd)
108108

109109
self.console.print("[yellow]got the output: " + result)
110110
cmd = self.llm.get_response(template_lse, lse_output=result, number=3)

usecases/privesc/templates/query_next_command.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ You currently expect the following about the target system:
1818
${state}
1919
%endif
2020
% if hint:
21-
You are provided a hint: ${hint}
21+
You are provided the following guidance: ${hint}
2222
%endif
2323

24-
Give your command. Do not add any explanation or add an initial `$`.
24+
State your command. You should focus upon enumeration and privilege escalation. Do not add any explanation or add an initial `$`.

utils/configurable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def get_parameters(fun, basename: str, fields: Dict[str, dataclasses.Field] = No
114114

115115
if hasattr(type, "__parameters__"):
116116
params[name] = ComplexParameterDefinition(name, type, default, description, get_class_parameters(type, f"{basename}.{fun.__name__}"))
117-
elif type in (str, int, bool):
117+
elif type in (str, int, float, bool):
118118
params[name] = ParameterDefinition(name, type, default, description)
119119
else:
120120
raise ValueError(f"Parameter {name} of {basename}.{fun.__name__} must have str, int, bool, or a __parameters__ class as type, not {type}")

0 commit comments

Comments
 (0)