Skip to content

Commit a6c4d9a

Browse files
Add provider configuration support to create session (#25)
* Add provider configuration support to create session * Add tests similar to existing ones * Factor out common logic * Reformat --------- Co-authored-by: Steve Sanderson <SteveSandersonMS@users.noreply.github.com>
1 parent d6e4e33 commit a6c4d9a

3 files changed

Lines changed: 55 additions & 9 deletions

File tree

python/copilot/client.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,11 @@ async def create_session(self, config: Optional[SessionConfig] = None) -> Copilo
242242
if streaming is not None:
243243
payload["streaming"] = streaming
244244

245+
# Add provider configuration if provided
246+
provider = cfg.get("provider")
247+
if provider:
248+
payload["provider"] = self._convert_provider_to_wire_format(provider)
249+
245250
if not self._client:
246251
raise RuntimeError("Client not connected")
247252
response = await self._client.request("session.create", payload)
@@ -293,15 +298,7 @@ async def resume_session(
293298

294299
provider = cfg.get("provider")
295300
if provider:
296-
# Convert snake_case to camelCase for the wire format
297-
wire_provider = {"type": provider.get("type")}
298-
if "base_url" in provider:
299-
wire_provider["baseUrl"] = provider["base_url"]
300-
if "api_key" in provider:
301-
wire_provider["apiKey"] = provider["api_key"]
302-
if "wire_api" in provider:
303-
wire_provider["wireApi"] = provider["wire_api"]
304-
payload["provider"] = wire_provider
301+
payload["provider"] = self._convert_provider_to_wire_format(provider)
305302

306303
# Add streaming option if provided
307304
streaming = cfg.get("streaming")
@@ -339,6 +336,26 @@ async def ping(self, message: Optional[str] = None) -> dict:
339336

340337
return await self._client.request("ping", {"message": message})
341338

339+
def _convert_provider_to_wire_format(self, provider: Dict[str, Any]) -> Dict[str, Any]:
340+
"""Convert provider config from snake_case to camelCase wire format."""
341+
wire_provider: Dict[str, Any] = {"type": provider.get("type")}
342+
if "base_url" in provider:
343+
wire_provider["baseUrl"] = provider["base_url"]
344+
if "api_key" in provider:
345+
wire_provider["apiKey"] = provider["api_key"]
346+
if "wire_api" in provider:
347+
wire_provider["wireApi"] = provider["wire_api"]
348+
if "bearer_token" in provider:
349+
wire_provider["bearerToken"] = provider["bearer_token"]
350+
if "azure" in provider:
351+
azure = provider["azure"]
352+
wire_azure: Dict[str, Any] = {}
353+
if "api_version" in azure:
354+
wire_azure["apiVersion"] = azure["api_version"]
355+
if wire_azure:
356+
wire_provider["azure"] = wire_azure
357+
return wire_provider
358+
342359
async def _start_cli_server(self) -> None:
343360
"""Start the CLI server process"""
344361
cli_path = self.options["cli_path"]

python/copilot/types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,8 @@ class SessionConfig(TypedDict, total=False):
128128
available_tools: list[str]
129129
# List of tool names to disable (ignored if available_tools is set)
130130
excluded_tools: list[str]
131+
# Custom provider configuration (BYOK - Bring Your Own Key)
132+
provider: ProviderConfig
131133
# Enable streaming of assistant message and reasoning chunks
132134
# When True, assistant.message_delta and assistant.reasoning_delta events
133135
# with delta_content are sent as the response is generated

python/e2e/test_session.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,33 @@ def get_secret_number_handler(invocation):
208208
answer = await get_final_assistant_message(session)
209209
assert "54321" in answer.data.content
210210

211+
async def test_should_create_session_with_custom_provider(self, ctx: E2ETestContext):
212+
session = await ctx.client.create_session(
213+
{
214+
"provider": {
215+
"type": "openai",
216+
"base_url": "https://api.openai.com/v1",
217+
"api_key": "fake-key",
218+
}
219+
}
220+
)
221+
assert session.session_id
222+
223+
async def test_should_create_session_with_azure_provider(self, ctx: E2ETestContext):
224+
session = await ctx.client.create_session(
225+
{
226+
"provider": {
227+
"type": "azure",
228+
"base_url": "https://my-resource.openai.azure.com",
229+
"api_key": "fake-key",
230+
"azure": {
231+
"api_version": "2024-02-15-preview",
232+
},
233+
}
234+
}
235+
)
236+
assert session.session_id
237+
211238
async def test_should_resume_session_with_custom_provider(self, ctx: E2ETestContext):
212239
session = await ctx.client.create_session()
213240
session_id = session.session_id

0 commit comments

Comments
 (0)