diff --git a/docs/USAGE.md b/docs/USAGE.md index 0dc264e9b..0416485ef 100644 --- a/docs/USAGE.md +++ b/docs/USAGE.md @@ -15,7 +15,7 @@ agents with specialized capabilities and tools. It features: - **📦 Agent distribution** via Docker registry integration - **🔒 Security-first design** with proper client scoping and resource isolation - **⚡ Event-driven streaming** for real-time interactions -- **🧠 Multi-model support** (OpenAI, Anthropic, Gemini, [Docker Model Runner (DMR)](https://docs.docker.com/ai/model-runner/)) +- **🧠 Multi-model support** (OpenAI, Anthropic, Gemini, Amazon Bedrock, [Docker Model Runner (DMR)](https://docs.docker.com/ai/model-runner/)) ## Why? @@ -193,7 +193,7 @@ cagent run ./agent.yaml /analyze | Property | Type | Description | Required | |---------------------|------------|------------------------------------------------------------------------------|----------| -| `provider` | string | Provider: `openai`, `anthropic`, `google`, `dmr` | ✓ | +| `provider` | string | Provider: `openai`, `anthropic`, `google`, `amazon-bedrock`, `dmr` | ✓ | | `model` | string | Model name (e.g., `gpt-4o`, `claude-sonnet-4-0`, `gemini-2.5-flash`) | ✓ | | `temperature` | float | Randomness (0.0-1.0) | ✗ | | `max_tokens` | integer | Response length limit | ✗ | @@ -208,8 +208,8 @@ cagent run ./agent.yaml /analyze ```yaml models: model_name: - provider: string # Provider: openai, anthropic, google, dmr - model: string # Model name: gpt-4o, claude-3-7-sonnet-latest, gemini-2.5-flash, qwen3:4B, ... + provider: string # Provider: openai, anthropic, google, amazon-bedrock, dmr + model: string # Model name: gpt-4o, claude-3-7-sonnet-latest, anthropic.claude-3-5-sonnet-20241022-v2:0, gemini-2.5-flash, qwen3:4B, ... 
temperature: float # Randomness (0.0-1.0) max_tokens: integer # Response length limit top_p: float # Nucleus sampling (0.0-1.0) @@ -334,6 +334,12 @@ models: provider: google model: gemini-2.5-flash +# Amazon Bedrock +models: + bedrock-claude: + provider: amazon-bedrock + model: anthropic.claude-3-5-sonnet-20241022-v2:0 + # Docker Model Runner (DMR) models: qwen: @@ -461,6 +467,57 @@ These options work alongside `max_tokens` (which sets `--context-size`) and `run - Endpoint empty in status: ensure the Model Runner is running, or set `base_url` manually - Flag parsing: if using a single string, quote properly in YAML; you can also use a list +#### Amazon Bedrock provider usage + +The `amazon-bedrock` provider enables access to various AI models hosted on Amazon Bedrock, including Anthropic Claude, Amazon Titan, Meta Llama, and Mistral models. + +**Authentication:** + +The Bedrock provider supports two authentication methods: + +1. **Bearer Token** (via `AWS_BEDROCK_TOKEN` environment variable): + - Use for custom authentication services or proxy scenarios + - If set, the provider ignores the standard AWS credential chain and sends the token as an `Authorization: Bearer` header in place of the AWS Signature v4 `Authorization` header + - Example: `export AWS_BEDROCK_TOKEN="your-token-here"` + +2. 
**Standard AWS Credentials** (default chain): + - Environment variables: `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `AWS_SESSION_TOKEN` + - AWS profile: `AWS_PROFILE` environment variable or default profile in `~/.aws/credentials` + - IAM role (for EC2/ECS/Lambda environments) + +**Configuration:** + +Basic configuration: + +```yaml +models: + bedrock-claude: + provider: amazon-bedrock + model: anthropic.claude-3-5-sonnet-20241022-v2:0 + temperature: 0.7 + max_tokens: 4000 +``` + +With custom region: + +```yaml +models: + bedrock-claude: + provider: amazon-bedrock + model: anthropic.claude-3-5-sonnet-20241022-v2:0 + provider_opts: + region: us-west-2 # Optional, defaults to AWS_REGION or us-east-1 +``` + +With custom endpoint (for VPC endpoints or proxies): + +```yaml +models: + bedrock-claude: + provider: amazon-bedrock + model: anthropic.claude-3-5-sonnet-20241022-v2:0 + base_url: https://bedrock-runtime.us-east-1.amazonaws.com +``` ### Alloy models diff --git a/examples/bedrock.yaml b/examples/bedrock.yaml new file mode 100644 index 000000000..3a09ba9aa --- /dev/null +++ b/examples/bedrock.yaml @@ -0,0 +1,44 @@ +version: "2" + +agents: + root: + model: claude + description: "Test multi-tool agent" + instruction: | + You are a test agent that demonstrates multi-tool calling. + + ⚠️ CRITICAL INSTRUCTIONS - YOU MUST FOLLOW THESE EXACTLY: + + When the user asks you to call multiple tools: + 1. Call the FIRST tool and wait for its result + 2. After receiving the first result, you MUST call the SECOND tool + 3. After receiving the second result, ONLY THEN respond to the user + + DO NOT respond with text after calling just one tool. + DO NOT say "I'll call the tools" - JUST CALL THEM. + DO NOT provide your final answer until you have called ALL required tools. + + If the task requires N tools, you must make N separate tool calls. + Each tool call must be followed by receiving its result before the next tool call. 
+ + Example correct sequence: + - Tool Call 1 → Receive Result 1 + - Tool Call 2 → Receive Result 2 + - Final Response with both results + + WRONG (do not do this): + - Tool Call 1 → Receive Result 1 → Final Response (MISSING TOOL 2!) + + max_iterations: 20 + toolsets: + - type: shell + +models: + claude: + provider: amazon-bedrock + model: global.anthropic.claude-haiku-4-5-20251001-v1:0 + temperature: 0.7 + max_tokens: 4000 + # Optional: specify AWS region (defaults to AWS_REGION env var or us-east-1) + # provider_opts: + # region: us-west-2 diff --git a/go.mod b/go.mod index 20baca790..3498852c2 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,10 @@ require ( github.com/alpkeskin/gotoon v0.1.1 github.com/anthropics/anthropic-sdk-go v1.19.0 github.com/atotto/clipboard v0.1.4 + github.com/aws/aws-sdk-go-v2 v1.39.3 + github.com/aws/aws-sdk-go-v2/config v1.31.13 + github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.41.1 + github.com/aws/smithy-go v1.23.1 github.com/aymanbagabas/go-udiff v0.3.1 github.com/bmatcuk/doublestar/v4 v4.9.1 github.com/charmbracelet/glamour/v2 v2.0.0-20251106195642-800eb8175930 @@ -61,6 +65,17 @@ require ( cloud.google.com/go/compute/metadata v0.9.0 // indirect dario.cat/mergo v1.0.2 // indirect github.com/JohannesKaufmann/dom v0.2.0 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.2 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.18.17 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.10 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.10 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.10 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.10 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.29.7 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.2 // 
indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.38.7 // indirect github.com/ProtonMail/go-crypto v1.1.6 // indirect github.com/aymerick/douceur v0.2.0 // indirect github.com/bahlo/generic-list-go v0.2.0 // indirect diff --git a/go.sum b/go.sum index 036c503c5..ddc40a2ad 100644 --- a/go.sum +++ b/go.sum @@ -43,6 +43,36 @@ github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPd github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= +github.com/aws/aws-sdk-go-v2 v1.39.3 h1:h7xSsanJ4EQJXG5iuW4UqgP7qBopLpj84mpkNx3wPjM= +github.com/aws/aws-sdk-go-v2 v1.39.3/go.mod h1:yWSxrnioGUZ4WVv9TgMrNUeLV3PFESn/v+6T/Su8gnM= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.2 h1:t9yYsydLYNBk9cJ73rgPhPWqOh/52fcWDQB5b1JsKSY= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.2/go.mod h1:IusfVNTmiSN3t4rhxWFaBAqn+mcNdwKtPcV16eYdgko= +github.com/aws/aws-sdk-go-v2/config v1.31.13 h1:wcqQB3B0PgRPUF5ZE/QL1JVOyB0mbPevHFoAMpemR9k= +github.com/aws/aws-sdk-go-v2/config v1.31.13/go.mod h1:ySB5D5ybwqGbT6c3GszZ+u+3KvrlYCUQNo62+hkKOFk= +github.com/aws/aws-sdk-go-v2/credentials v1.18.17 h1:skpEwzN/+H8cdrrtT8y+rvWJGiWWv0DeNAe+4VTf+Vs= +github.com/aws/aws-sdk-go-v2/credentials v1.18.17/go.mod h1:Ed+nXsaYa5uBINovJhcAWkALvXw2ZLk36opcuiSZfJM= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.10 h1:UuGVOX48oP4vgQ36oiKmW9RuSeT8jlgQgBFQD+HUiHY= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.10/go.mod h1:vM/Ini41PzvudT4YkQyE/+WiQJiQ6jzeDyU8pQKwCac= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.10 h1:mj/bdWleWEh81DtpdHKkw41IrS+r3uw1J/VQtbwYYp8= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.10/go.mod h1:7+oEMxAZWP8gZCyjcm9VicI0M61Sx4DJtcGfKYv2yKQ= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 
v2.7.10 h1:wh+/mn57yhUrFtLIxyFPh2RgxgQz/u+Yrf7hiHGHqKY= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.10/go.mod h1:7zirD+ryp5gitJJ2m1BBux56ai8RIRDykXZrJSp540w= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 h1:WKuaxf++XKWlHWu9ECbMlha8WOEGm0OUEZqm4K/Gcfk= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4/go.mod h1:ZWy7j6v1vWGmPReu0iSGvRiise4YI5SkR3OHKTZ6Wuc= +github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.41.1 h1:sscdABXtedWQ+5I0YnxawJwrX1YzbPhIs7TklRaRDpk= +github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.41.1/go.mod h1:LVJ9jAJ1nuUyhovH5z7GAA/FktQOMarcZgGeqiHQJPo= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.2 h1:xtuxji5CS0JknaXoACOunXOYOQzgfTvGAc9s2QdCJA4= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.2/go.mod h1:zxwi0DIR0rcRcgdbl7E2MSOvxDyyXGBlScvBkARFaLQ= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.10 h1:DRND0dkCKtJzCj4Xl4OpVbXZgfttY5q712H9Zj7qc/0= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.10/go.mod h1:tGGNmJKOTernmR2+VJ0fCzQRurcPZj9ut60Zu5Fi6us= +github.com/aws/aws-sdk-go-v2/service/sso v1.29.7 h1:fspVFg6qMx0svs40YgRmE7LZXh9VRZvTT35PfdQR6FM= +github.com/aws/aws-sdk-go-v2/service/sso v1.29.7/go.mod h1:BQTKL3uMECaLaUV3Zc2L4Qybv8C6BIXjuu1dOPyxTQs= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.2 h1:scVnW+NLXasGOhy7HhkdT9AGb6kjgW7fJ5xYkUaqHs0= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.2/go.mod h1:FRNCY3zTEWZXBKm2h5UBUPvCVDOecTad9KhynDyGBc0= +github.com/aws/aws-sdk-go-v2/service/sts v1.38.7 h1:VEO5dqFkMsl8QZ2yHsFDJAIZLAkEbaYDB+xdKi0Feic= +github.com/aws/aws-sdk-go-v2/service/sts v1.38.7/go.mod h1:L1xxV3zAdB+qVrVW/pBIrIAnHFWHo6FBbFe4xOGsG/o= +github.com/aws/smithy-go v1.23.1 h1:sLvcH6dfAFwGkHLZ7dGiYF7aK6mg4CgKA/iDKjLDt9M= +github.com/aws/smithy-go v1.23.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= github.com/aymanbagabas/go-udiff v0.3.1 h1:LV+qyBQ2pqe0u42ZsUEtPiCaUoqgA9gYRDs3vj1nolY= github.com/aymanbagabas/go-udiff 
v0.3.1/go.mod h1:G0fsKmG+P6ylD0r6N/KgQD/nWzgfnl8ZBcNLgcbrw8E= github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk= diff --git a/pkg/model/provider/anthropic/client.go b/pkg/model/provider/anthropic/client.go index 0d25e9816..936bed484 100644 --- a/pkg/model/provider/anthropic/client.go +++ b/pkg/model/provider/anthropic/client.go @@ -103,6 +103,11 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro return nil, errors.New("model type must be 'anthropic'") } + // Ensure env is not nil - use default provider if nil + if env == nil { + env = environment.NewDefaultProvider() + } + var globalOptions options.ModelOptions for _, opt := range opts { opt(&globalOptions) diff --git a/pkg/model/provider/bedrock/adapter.go b/pkg/model/provider/bedrock/adapter.go new file mode 100644 index 000000000..6e7f68c1e --- /dev/null +++ b/pkg/model/provider/bedrock/adapter.go @@ -0,0 +1,218 @@ +package bedrock + +import ( + "fmt" + "io" + "log/slog" + + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" + + "github.com/docker/cagent/pkg/chat" + "github.com/docker/cagent/pkg/tools" +) + +// StreamAdapter adapts the Bedrock Converse stream to chat.MessageStream interface +type StreamAdapter struct { + stream *bedrockruntime.ConverseStreamOutput + model string + eventStream <-chan types.ConverseStreamOutput + toolCallData map[int]*toolCallInfo // Track tool call data by index +} + +// toolCallInfo holds information about a tool call being streamed +type toolCallInfo struct { + ID string + Name string + Arguments string +} + +// newStreamAdapter creates a new stream adapter for Converse API +func newStreamAdapter(output *bedrockruntime.ConverseStreamOutput, model string) *StreamAdapter { + return &StreamAdapter{ + stream: output, + model: model, + eventStream: output.GetStream().Events(), + toolCallData: make(map[int]*toolCallInfo), + } +} + +// Recv gets the next 
completion chunk from the Converse API stream. +// It returns io.EOF when the stream ends cleanly, and the stream's terminal error +// when the event channel closes because of a failure. +func (a *StreamAdapter) Recv() (chat.MessageStreamResponse, error) { + if a.eventStream == nil { + return chat.MessageStreamResponse{}, io.EOF + } + + event, ok := <-a.eventStream + if !ok { + // Channel closed: the SDK exposes mid-stream failures through Err(), not as events, + // so check it before signalling a clean end of stream. + if a.stream != nil && a.stream.GetStream() != nil { + if err := a.stream.GetStream().Err(); err != nil { + return chat.MessageStreamResponse{}, fmt.Errorf("bedrock converse stream: %w", err) + } + } + return chat.MessageStreamResponse{}, io.EOF + } + + return a.processConverseEvent(event) +} + +// processConverseEvent processes a Converse API stream event +func (a *StreamAdapter) processConverseEvent(event types.ConverseStreamOutput) (chat.MessageStreamResponse, error) { + response := chat.MessageStreamResponse{ + Model: a.model, + Choices: []chat.MessageStreamChoice{ + { + Index: 0, + Delta: chat.MessageDelta{ + Role: string(chat.MessageRoleAssistant), + }, + }, + }, + } + + switch e := event.(type) { + case *types.ConverseStreamOutputMemberMessageStart: + // Message start event - provides role + slog.Debug("Converse MessageStart event", "role", e.Value.Role) + return response, nil + + case *types.ConverseStreamOutputMemberContentBlockStart: + // Content block start - may be text or tool use + if e.Value.Start != nil { + switch start := e.Value.Start.(type) { + case *types.ContentBlockStartMemberToolUse: + // Tool use started - check for nil values before dereferencing + if start.Value.ToolUseId == nil || start.Value.Name == nil { + slog.Warn("Converse ContentBlockStart (ToolUse) missing required fields", + "tool_use_id_nil", start.Value.ToolUseId == nil, + "name_nil", start.Value.Name == nil) + return response, nil + } + + toolCall := tools.ToolCall{ + ID: *start.Value.ToolUseId, + Type: "function", + Function: tools.FunctionCall{ + Name: *start.Value.Name, + }, + } + response.Choices[0].Delta.ToolCalls = []tools.ToolCall{toolCall} + + // Store tool call info for delta events + if e.Value.ContentBlockIndex != nil { + a.toolCallData[int(*e.Value.ContentBlockIndex)] = &toolCallInfo{ + ID: *start.Value.ToolUseId, + Name: *start.Value.Name, + } + } + + slog.Debug("Converse ContentBlockStart (ToolUse)", + 
"tool_id", *start.Value.ToolUseId, + "tool_name", *start.Value.Name, + "index", e.Value.ContentBlockIndex) + } + } + return response, nil + + case *types.ConverseStreamOutputMemberContentBlockDelta: + // Content block delta - streaming content + if e.Value.Delta != nil { + switch delta := e.Value.Delta.(type) { + case *types.ContentBlockDeltaMemberText: + // Text content delta + response.Choices[0].Delta.Content = delta.Value + slog.Debug("Converse ContentBlockDelta (Text)", "length", len(delta.Value)) + + case *types.ContentBlockDeltaMemberToolUse: + // Tool use input delta (streaming JSON arguments) + // Accumulate but DON'T send to runtime until complete + if e.Value.ContentBlockIndex != nil && delta.Value.Input != nil { + if toolInfo, ok := a.toolCallData[int(*e.Value.ContentBlockIndex)]; ok { + toolInfo.Arguments += *delta.Value.Input + slog.Debug("Converse ContentBlockDelta (ToolUse) accumulated", + "chunk_length", len(*delta.Value.Input), + "total_length", len(toolInfo.Arguments)) + } + } + } + } + return response, nil + + case *types.ConverseStreamOutputMemberContentBlockStop: + // Content block stopped - now send complete tool call if this was a tool use block + if e.Value.ContentBlockIndex != nil { + if toolInfo, ok := a.toolCallData[int(*e.Value.ContentBlockIndex)]; ok { + // Ensure arguments is valid JSON - if empty, use empty object + args := toolInfo.Arguments + if args == "" { + args = "{}" + } + + // Send the complete tool call now + toolCall := tools.ToolCall{ + ID: toolInfo.ID, + Type: "function", + Function: tools.FunctionCall{ + Name: toolInfo.Name, + Arguments: args, + }, + } + response.Choices[0].Delta.ToolCalls = []tools.ToolCall{toolCall} + slog.Debug("Converse ContentBlockStop - sending complete tool call", + "tool_id", toolInfo.ID, + "tool_name", toolInfo.Name, + "args_length", len(args)) + } + } + slog.Debug("Converse ContentBlockStop", "index", e.Value.ContentBlockIndex) + return response, nil + + case 
*types.ConverseStreamOutputMemberMessageStop: + // Message stopped - provides stop reason + if e.Value.StopReason != "" { + response.Choices[0].FinishReason = mapConverseStopReason(e.Value.StopReason) + slog.Debug("Converse MessageStop", "stop_reason", e.Value.StopReason) + } + return response, nil + + case *types.ConverseStreamOutputMemberMetadata: + // Metadata event - provides token usage + if e.Value.Usage != nil { + usage := &chat.Usage{} + if e.Value.Usage.InputTokens != nil { + usage.InputTokens = int64(*e.Value.Usage.InputTokens) + } + if e.Value.Usage.OutputTokens != nil { + usage.OutputTokens = int64(*e.Value.Usage.OutputTokens) + } + response.Usage = usage + slog.Debug("Converse Metadata", "input_tokens", usage.InputTokens, "output_tokens", usage.OutputTokens) + } + return response, nil + + default: + slog.Warn("Unexpected Converse stream event", "type", fmt.Sprintf("%T", event)) + return chat.MessageStreamResponse{}, fmt.Errorf("unexpected stream event: %T", event) + } +} + +// mapConverseStopReason maps Converse API stop reasons to standard finish reasons +func mapConverseStopReason(reason types.StopReason) chat.FinishReason { + switch reason { + case types.StopReasonEndTurn: + return chat.FinishReasonStop + case types.StopReasonMaxTokens: + return chat.FinishReasonLength + case types.StopReasonToolUse: + return chat.FinishReasonToolCalls + case types.StopReasonStopSequence: + return chat.FinishReasonStop + case types.StopReasonContentFiltered: + return chat.FinishReasonContentFilter + default: + slog.Warn("Unknown stop reason", "reason", reason) + return chat.FinishReasonStop + } +} + +// Close releases the underlying Bedrock event stream. It is a no-op on an +// adapter that has no stream attached. +func (a *StreamAdapter) Close() { + if a.stream == nil || a.stream.GetStream() == nil { + return + } + // Best-effort: this method has no error return to report a close failure through + _ = a.stream.GetStream().Close() +} diff --git a/pkg/model/provider/bedrock/client.go b/pkg/model/provider/bedrock/client.go new file mode 100644 index 000000000..b9f9c5db6 --- /dev/null +++ b/pkg/model/provider/bedrock/client.go @@ 
-0,0 +1,396 @@ +// Package bedrock implements the AWS Bedrock provider for cagent. +// It supports multiple model families including Anthropic Claude, Amazon Titan, +// Meta Llama, and Mistral models via AWS Bedrock Runtime API. +// Authentication can be done via AWS credentials or bearer token. +package bedrock + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "net/http" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/document" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" + + "github.com/docker/cagent/pkg/chat" + "github.com/docker/cagent/pkg/config/latest" + "github.com/docker/cagent/pkg/environment" + "github.com/docker/cagent/pkg/model/provider/base" + "github.com/docker/cagent/pkg/model/provider/options" + cagentTools "github.com/docker/cagent/pkg/tools" +) + +// bearerTokenTransport wraps an http.RoundTripper and adds Bearer token authentication +type bearerTokenTransport struct { + token string + transport http.RoundTripper +} + +func (t *bearerTokenTransport) RoundTrip(req *http.Request) (*http.Response, error) { + // Clone the request to avoid modifying the original + clonedReq := req.Clone(req.Context()) + clonedReq.Header.Set("Authorization", "Bearer "+t.token) + return t.transport.RoundTrip(clonedReq) +} + +// Client represents a Bedrock client wrapper implementing provider.Provider +type Client struct { + base.Config + client *bedrockruntime.Client + region string +} + +// NewClient creates a new Bedrock client from the provided configuration +func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Provider, opts ...options.Opt) (*Client, error) { + if cfg == nil { + slog.Error("Bedrock client creation failed", "error", "model configuration is required") + return nil, errors.New("model configuration is 
required") + } + + if cfg.Provider != "amazon-bedrock" { + slog.Error("Bedrock client creation failed", "error", "model type must be 'amazon-bedrock'", "actual_type", cfg.Provider) + return nil, errors.New("model type must be 'amazon-bedrock'") + } + + // Ensure env is not nil - use default provider if nil + if env == nil { + env = environment.NewDefaultProvider() + } + + var globalOptions options.ModelOptions + for _, opt := range opts { + opt(&globalOptions) + } + + // Determine region: explicit config takes precedence over environment variable + region := "us-east-1" // default + if envRegion := env.Get(ctx, "AWS_REGION"); envRegion != "" { + region = envRegion + } + // Explicit ProviderOpts config takes precedence over environment variable + if cfg.ProviderOpts != nil { + if r, ok := cfg.ProviderOpts["region"]; ok { + if regionStr, ok := r.(string); ok && regionStr != "" { + region = regionStr + } + } + } + + // Check for bearer token authentication first + bearerToken := env.Get(ctx, "AWS_BEDROCK_TOKEN") + + var awsCfg aws.Config + var err error + + if bearerToken != "" { + slog.Debug("Bedrock using bearer token authentication", "token_present", true) + // For bearer token auth (proxy/gateway scenarios), we provide static credentials + // to satisfy the SDK's auth requirements, but our custom HTTP transport will + // replace the Authorization header with the bearer token. + // The following credentials are AWS documentation example credentials (see: + // https://docs.aws.amazon.com/general/latest/gr/aws-sec-cred-types.html#access-keys-and-secret-access-keys). + // They are intentionally fake and used only to satisfy the SDK's authentication requirements + // when using bearer token authentication. The actual authorization is handled by the + // bearerTokenTransport, which replaces the Authorization header with the bearer token. 
+ staticCreds := credentials.NewStaticCredentialsProvider("AKIAIOSFODNN7EXAMPLE", "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", "") + + awsCfg, err = config.LoadDefaultConfig(ctx, + config.WithRegion(region), + config.WithCredentialsProvider(staticCreds), + config.WithHTTPClient(&http.Client{ + Transport: &bearerTokenTransport{ + token: bearerToken, + transport: http.DefaultTransport, + }, + }), + ) + if err != nil { + slog.Error("Failed to load AWS config for bearer token auth", "error", err) + return nil, fmt.Errorf("failed to load AWS config: %w", err) + } + } else { + // Use standard AWS credential chain + slog.Debug("Bedrock using AWS credential chain", "region", region) + awsCfg, err = config.LoadDefaultConfig(ctx, config.WithRegion(region)) + if err != nil { + slog.Error("Failed to load AWS config", "error", err) + return nil, fmt.Errorf("failed to load AWS config: %w", err) + } + } + + // Build client options + clientOpts := []func(*bedrockruntime.Options){ + func(o *bedrockruntime.Options) { + o.Region = region + }, + } + + // Add custom endpoint if specified + if cfg.BaseURL != "" { + slog.Debug("Bedrock using custom endpoint", "endpoint", cfg.BaseURL) + clientOpts = append(clientOpts, func(o *bedrockruntime.Options) { + o.BaseEndpoint = aws.String(cfg.BaseURL) + }) + } + + client := bedrockruntime.NewFromConfig(awsCfg, clientOpts...) 
+ slog.Debug("Bedrock client created successfully", "model", cfg.Model, "region", region) + + return &Client{ + Config: base.Config{ + ModelConfig: *cfg, + ModelOptions: globalOptions, + Env: env, + }, + client: client, + region: region, + }, nil +} + +// CreateChatCompletionStream creates a streaming chat completion request using Converse API +func (c *Client) CreateChatCompletionStream( + ctx context.Context, + messages []chat.Message, + requestTools []cagentTools.Tool, +) (chat.MessageStream, error) { + slog.Debug("Creating Bedrock chat completion stream", + "model", c.ModelConfig.Model, + "message_count", len(messages), + "tool_count", len(requestTools), + "region", c.region) + + if len(messages) == 0 { + slog.Error("Bedrock stream creation failed", "error", "at least one message is required") + return nil, errors.New("at least one message is required") + } + + // Convert messages to Converse API format + converseMessages, systemBlocks, err := convertToConverseMessages(messages) + if err != nil { + slog.Error("Failed to convert messages", "error", err) + return nil, fmt.Errorf("failed to convert messages: %w", err) + } + + // Build Converse API input + input := &bedrockruntime.ConverseStreamInput{ + ModelId: aws.String(c.ModelConfig.Model), + Messages: converseMessages, + } + + // Add system prompts if present + if len(systemBlocks) > 0 { + input.System = systemBlocks + } + + // Add inference configuration + inferenceConfig := &types.InferenceConfiguration{} + if c.ModelConfig.MaxTokens > 0 { + inferenceConfig.MaxTokens = aws.Int32(int32(c.ModelConfig.MaxTokens)) + } + if c.ModelConfig.Temperature != nil { + inferenceConfig.Temperature = aws.Float32(float32(*c.ModelConfig.Temperature)) + } + if c.ModelConfig.TopP != nil { + inferenceConfig.TopP = aws.Float32(float32(*c.ModelConfig.TopP)) + } + input.InferenceConfig = inferenceConfig + + // Add tools if provided + if len(requestTools) > 0 { + converseTools, err := convertToConverseTools(requestTools) + if err != 
nil { + slog.Error("Failed to convert tools", "error", err) + return nil, fmt.Errorf("failed to convert tools: %w", err) + } + input.ToolConfig = &types.ToolConfiguration{ + Tools: converseTools, + } + } + + // Invoke model with streaming + output, err := c.client.ConverseStream(ctx, input) + if err != nil { + slog.Error("Bedrock stream creation failed", "error", err, "model", c.ModelConfig.Model) + return nil, fmt.Errorf("failed to invoke model: %w", err) + } + + slog.Debug("Bedrock chat completion stream created successfully", "model", c.ModelConfig.Model) + return newStreamAdapter(output, c.ModelConfig.Model), nil +} + +// convertToConverseMessages converts cagent messages to Converse API format +func convertToConverseMessages(messages []chat.Message) ([]types.Message, []types.SystemContentBlock, error) { + var converseMessages []types.Message + var systemBlocks []types.SystemContentBlock + + for i := 0; i < len(messages); i++ { + msg := messages[i] + + // System messages are handled separately + if msg.Role == chat.MessageRoleSystem { + systemBlocks = append(systemBlocks, &types.SystemContentBlockMemberText{ + Value: msg.Content, + }) + continue + } + + // Convert role + var role types.ConversationRole + switch msg.Role { + case chat.MessageRoleUser: + role = types.ConversationRoleUser + case chat.MessageRoleAssistant: + role = types.ConversationRoleAssistant + case chat.MessageRoleTool: + // Tool results are sent as user messages with tool result blocks + role = types.ConversationRoleUser + default: + return nil, nil, fmt.Errorf("unsupported message role: %s", msg.Role) + } + + // Build content blocks + var contentBlocks []types.ContentBlock + + // Handle tool results - group consecutive tool results into one user message + if msg.Role == chat.MessageRoleTool && msg.ToolCallID != "" { + // Collect all consecutive tool results + toolResults := []chat.Message{msg} + j := i + 1 + for j < len(messages) && messages[j].Role == chat.MessageRoleTool { + toolResults = 
append(toolResults, messages[j]) + j++ + } + + // Convert all tool results into content blocks + for _, tr := range toolResults { + var toolResultContent []types.ToolResultContentBlock + toolResultContent = append(toolResultContent, &types.ToolResultContentBlockMemberText{ + Value: tr.Content, + }) + + contentBlocks = append(contentBlocks, &types.ContentBlockMemberToolResult{ + Value: types.ToolResultBlock{ + ToolUseId: aws.String(tr.ToolCallID), + Content: toolResultContent, + }, + }) + } + + // Skip the tool results we already processed + i = j - 1 + } else if msg.Role == chat.MessageRoleAssistant && len(msg.ToolCalls) > 0 { + // Assistant message with tool calls + if msg.Content != "" { + contentBlocks = append(contentBlocks, &types.ContentBlockMemberText{ + Value: msg.Content, + }) + } + + // Add tool use blocks + for _, tc := range msg.ToolCalls { + // Parse tool arguments + var input map[string]any + if tc.Function.Arguments != "" { + if err := json.Unmarshal([]byte(tc.Function.Arguments), &input); err != nil { + slog.Warn("Failed to unmarshal tool arguments", "tool", tc.Function.Name, "error", err) + input = make(map[string]any) + } + } else { + input = make(map[string]any) + } + + // Convert to document type + inputDoc, err := convertToDocument(input) + if err != nil { + return nil, nil, fmt.Errorf("failed to convert tool input: %w", err) + } + + contentBlocks = append(contentBlocks, &types.ContentBlockMemberToolUse{ + Value: types.ToolUseBlock{ + ToolUseId: aws.String(tc.ID), + Name: aws.String(tc.Function.Name), + Input: inputDoc, + }, + }) + } + } else { + // Regular text message + if msg.Content != "" { + contentBlocks = append(contentBlocks, &types.ContentBlockMemberText{ + Value: msg.Content, + }) + } + } + + if len(contentBlocks) > 0 { + converseMessages = append(converseMessages, types.Message{ + Role: role, + Content: contentBlocks, + }) + } + } + + return converseMessages, systemBlocks, nil +} + +// convertToConverseTools converts cagent tools to 
Converse API format +func convertToConverseTools(tools []cagentTools.Tool) ([]types.Tool, error) { + var converseTools []types.Tool + + for _, tool := range tools { + // Convert tool parameters schema + schemaMap, err := cagentTools.SchemaToMap(tool.Parameters) + if err != nil { + return nil, fmt.Errorf("failed to convert tool %s schema: %w", tool.Name, err) + } + + // Convert schema to document + inputSchema, err := convertToDocument(schemaMap) + if err != nil { + return nil, fmt.Errorf("failed to convert tool %s input schema: %w", tool.Name, err) + } + + converseTools = append(converseTools, &types.ToolMemberToolSpec{ + Value: types.ToolSpecification{ + Name: aws.String(tool.Name), + Description: aws.String(tool.Description), + InputSchema: &types.ToolInputSchemaMemberJson{ + Value: inputSchema, + }, + }, + }) + } + + return converseTools, nil +} + +// convertToDocument converts a map to AWS document type +func convertToDocument(data map[string]any) (document.Interface, error) { + // Remove fields that Bedrock doesn't accept in JSON Schema + cleanedData := make(map[string]any) + for k, v := range data { + // Skip additionalProperties as Bedrock might not accept it + if k == "additionalProperties" { + continue + } + // Convert nil "required" field to empty array, as Bedrock expects an array + if k == "required" && v == nil { + cleanedData[k] = []string{} + } else { + cleanedData[k] = v + } + } + + slog.Debug("Converting to document", "data", cleanedData) + + // Create lazy document from the map structure directly, not JSON bytes + // NewLazyDocument will handle the marshaling internally + return document.NewLazyDocument(cleanedData), nil +} diff --git a/pkg/model/provider/bedrock/client_test.go b/pkg/model/provider/bedrock/client_test.go new file mode 100644 index 000000000..45a73bc29 --- /dev/null +++ b/pkg/model/provider/bedrock/client_test.go @@ -0,0 +1,174 @@ +package bedrock + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + 
"github.com/stretchr/testify/require" + + "github.com/docker/cagent/pkg/config/latest" + "github.com/docker/cagent/pkg/environment" +) + +// testMapProvider is a simple map-based environment provider for testing +type testMapProvider struct { + values map[string]string +} + +func newTestMapProvider(values map[string]string) *testMapProvider { + return &testMapProvider{values: values} +} + +func (p *testMapProvider) Get(_ context.Context, name string) string { + return p.values[name] +} + +var _ environment.Provider = (*testMapProvider)(nil) + +func TestNewClient_ValidConfig(t *testing.T) { + t.Parallel() + ctx := t.Context() + env := newTestMapProvider(map[string]string{ + "AWS_BEDROCK_TOKEN": "test-token", + "AWS_REGION": "us-west-2", + }) + + cfg := &latest.ModelConfig{ + Provider: "amazon-bedrock", + Model: "anthropic.claude-3-5-sonnet-20241022-v2:0", + } + + client, err := NewClient(ctx, cfg, env) + require.NoError(t, err) + require.NotNil(t, client) + assert.Equal(t, "us-west-2", client.region) + assert.Equal(t, "amazon-bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0", client.ID()) +} + +func TestNewClient_CustomRegionFromConfig(t *testing.T) { + t.Parallel() + ctx := t.Context() + env := newTestMapProvider(map[string]string{ + "AWS_BEDROCK_TOKEN": "test-token", + }) + + cfg := &latest.ModelConfig{ + Provider: "amazon-bedrock", + Model: "amazon.titan-text-express-v1", + ProviderOpts: map[string]any{ + "region": "eu-west-1", + }, + } + + client, err := NewClient(ctx, cfg, env) + require.NoError(t, err) + require.NotNil(t, client) + assert.Equal(t, "eu-west-1", client.region) +} + +func TestNewClient_ProviderOptsRegionTakesPrecedenceOverEnv(t *testing.T) { + t.Parallel() + ctx := t.Context() + // Set both AWS_REGION env var and ProviderOpts region + env := newTestMapProvider(map[string]string{ + "AWS_BEDROCK_TOKEN": "test-token", + "AWS_REGION": "us-west-2", + }) + + cfg := &latest.ModelConfig{ + Provider: "amazon-bedrock", + Model: 
"anthropic.claude-3-5-sonnet-20241022-v2:0", + ProviderOpts: map[string]any{ + "region": "eu-central-1", // Explicit config should take precedence + }, + } + + client, err := NewClient(ctx, cfg, env) + require.NoError(t, err) + require.NotNil(t, client) + // ProviderOpts region should take precedence over AWS_REGION env var + assert.Equal(t, "eu-central-1", client.region) +} + +func TestNewClient_DefaultRegion(t *testing.T) { + t.Parallel() + ctx := t.Context() + env := newTestMapProvider(map[string]string{ + "AWS_BEDROCK_TOKEN": "test-token", + }) + + cfg := &latest.ModelConfig{ + Provider: "amazon-bedrock", + Model: "anthropic.claude-3-5-sonnet-20241022-v2:0", + } + + client, err := NewClient(ctx, cfg, env) + require.NoError(t, err) + require.NotNil(t, client) + assert.Equal(t, "us-east-1", client.region) // Default region +} + +func TestNewClient_CustomEndpoint(t *testing.T) { + t.Parallel() + ctx := t.Context() + env := newTestMapProvider(map[string]string{ + "AWS_BEDROCK_TOKEN": "test-token", + }) + + cfg := &latest.ModelConfig{ + Provider: "amazon-bedrock", + Model: "anthropic.claude-3-5-sonnet-20241022-v2:0", + BaseURL: "https://custom-bedrock-endpoint.example.com", + } + + client, err := NewClient(ctx, cfg, env) + require.NoError(t, err) + require.NotNil(t, client) +} + +func TestNewClient_WrongProviderType(t *testing.T) { + t.Parallel() + ctx := t.Context() + env := newTestMapProvider(map[string]string{ + "AWS_BEDROCK_TOKEN": "test-token", + }) + + cfg := &latest.ModelConfig{ + Provider: "openai", + Model: "gpt-4", + } + + _, err := NewClient(ctx, cfg, env) + require.Error(t, err) + assert.Contains(t, err.Error(), "model type must be 'amazon-bedrock'") +} + +func TestNewClient_NilConfig(t *testing.T) { + t.Parallel() + ctx := t.Context() + env := newTestMapProvider(map[string]string{}) + + _, err := NewClient(ctx, nil, env) + require.Error(t, err) + assert.Contains(t, err.Error(), "model configuration is required") +} + +func TestClientID(t *testing.T) { + 
t.Parallel() + ctx := t.Context() + env := newTestMapProvider(map[string]string{ + "AWS_BEDROCK_TOKEN": "test-token", + }) + + cfg := &latest.ModelConfig{ + Provider: "amazon-bedrock", + Model: "anthropic.claude-3-5-sonnet-20241022-v2:0", + } + + client, err := NewClient(ctx, cfg, env) + require.NoError(t, err) + + expectedID := "amazon-bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0" + assert.Equal(t, expectedID, client.ID()) +} diff --git a/pkg/model/provider/gemini/client.go b/pkg/model/provider/gemini/client.go index 9842c608d..d0e333a43 100644 --- a/pkg/model/provider/gemini/client.go +++ b/pkg/model/provider/gemini/client.go @@ -41,6 +41,11 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro return nil, errors.New("model type must be 'google'") } + // Ensure env is not nil - use default provider if nil + if env == nil { + env = environment.NewDefaultProvider() + } + var globalOptions options.ModelOptions for _, opt := range opts { opt(&globalOptions) diff --git a/pkg/model/provider/openai/client.go b/pkg/model/provider/openai/client.go index 1a1d3fc58..c71c9efca 100644 --- a/pkg/model/provider/openai/client.go +++ b/pkg/model/provider/openai/client.go @@ -40,6 +40,11 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro return nil, errors.New("model configuration is required") } + // Ensure env is not nil - use default provider if nil + if env == nil { + env = environment.NewDefaultProvider() + } + var globalOptions options.ModelOptions for _, opt := range opts { opt(&globalOptions) diff --git a/pkg/model/provider/provider.go b/pkg/model/provider/provider.go index 963cf6aed..35e7eba4e 100644 --- a/pkg/model/provider/provider.go +++ b/pkg/model/provider/provider.go @@ -10,6 +10,7 @@ import ( "github.com/docker/cagent/pkg/environment" "github.com/docker/cagent/pkg/model/provider/anthropic" "github.com/docker/cagent/pkg/model/provider/base" + "github.com/docker/cagent/pkg/model/provider/bedrock" 
"github.com/docker/cagent/pkg/model/provider/dmr" "github.com/docker/cagent/pkg/model/provider/gemini" "github.com/docker/cagent/pkg/model/provider/openai" @@ -121,6 +122,9 @@ func New(ctx context.Context, cfg *latest.ModelConfig, env environment.Provider, case "dmr": return dmr.NewClient(ctx, enhancedCfg, opts...) + case "amazon-bedrock": + return bedrock.NewClient(ctx, enhancedCfg, env, opts...) + default: slog.Error("Unknown provider type", "type", providerType) return nil, fmt.Errorf("unknown provider type: %s", providerType)