diff --git a/aiprompts/conn-arch.md b/aiprompts/conn-arch.md new file mode 100644 index 000000000..b03abbad2 --- /dev/null +++ b/aiprompts/conn-arch.md @@ -0,0 +1,612 @@ +# Wave Terminal Connection Architecture + +## Overview + +Wave Terminal's connection system is designed to provide a unified interface for running shell processes across local, SSH, and WSL environments. The architecture is built in layers, with clear separation of concerns between connection management, shell process execution, and block-level orchestration. + +## Architecture Layers + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ Block Controllers │ +│ (blockcontroller/blockcontroller.go, shellcontroller.go) │ +│ - Block lifecycle management │ +│ - Controller registry and switching │ +│ - Connection status verification │ +└─────────────────────────────────────────────────────────────────┘ + ↓ +┌─────────────────────────────────────────────────────────────────┐ +│ Connection Controllers (ConnUnion) │ +│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │ +│ │ Local │ │ SSH │ │ WSL │ │ +│ │ │ │ (conncontrol │ │ (wslconn) │ │ +│ │ │ │ ler) │ │ │ │ +│ └──────────────┘ └──────────────┘ └──────────────┘ │ +│ - Connection lifecycle (init → connecting → connected) │ +│ - WSH (Wave Shell Extensions) management │ +│ - Domain socket setup for RPC communication │ +└─────────────────────────────────────────────────────────────────┘ + ↓ +┌─────────────────────────────────────────────────────────────────┐ +│ Shell Process Execution │ +│ (shellexec/shellexec.go) │ +│ - ShellProc wrapper for running processes │ +│ - PTY management │ +│ - Process lifecycle (start, wait, kill) │ +└─────────────────────────────────────────────────────────────────┘ + ↓ +┌─────────────────────────────────────────────────────────────────┐ +│ Low-Level Connection Implementation │ +│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │ +│ │ os/exec │ │golang.org/x/ │ │ pkg/wsl │ │ +│ │ │ │ crypto/ssh │ │ │ │ +│ └──────────────┘ └──────────────┘ └──────────────┘ │ +│ - Local process spawning │ +│ - SSH protocol implementation │ +│ - WSL command execution │ +└─────────────────────────────────────────────────────────────────┘ +``` + +## Key Components + +### 1. Block Controllers (`pkg/blockcontroller/`) + +**Primary Files:** +- [`blockcontroller.go`](../pkg/blockcontroller/blockcontroller.go) - Controller registry and orchestration +- [`shellcontroller.go`](../pkg/blockcontroller/shellcontroller.go) - Shell/terminal controller implementation + +**Responsibilities:** +- **Controller Registry**: Maintains a global map of active block controllers (`controllerRegistry`) +- **Lifecycle Management**: Handles controller creation, starting, stopping, and switching +- **Connection Verification**: Checks connection status before starting shell processes ([`CheckConnStatus()`](../pkg/blockcontroller/blockcontroller.go:360)) +- **Controller Types**: Supports different controller types (shell, cmd, tsunami) + +**Key Functions:** +- [`ResyncController()`](../pkg/blockcontroller/blockcontroller.go:120) - Main entry point for synchronizing block state with desired controller +- [`registerController()`](../pkg/blockcontroller/blockcontroller.go:84) - Registers a new controller, stopping any existing one +- [`getController()`](../pkg/blockcontroller/blockcontroller.go:78) - Retrieves active controller for a block + +**ShellController Details:** +- Implements the `Controller` interface +- Manages shell processes via [`ShellProc`](../pkg/shellexec/shellexec.go:48) +- Handles three connection types via `ConnUnion`: + - **Local**: Direct process execution on local machine + - **SSH**: Remote execution via SSH connections + - **WSL**: Windows Subsystem for Linux execution +- Key methods: + - [`setupAndStartShellProcess()`](../pkg/blockcontroller/shellcontroller.go:364) - Sets up and starts shell process + - [`getConnUnion()`](../pkg/blockcontroller/shellcontroller.go:321) - Determines connection type and retrieves connection object + - [`manageRunningShellProcess()`](../pkg/blockcontroller/shellcontroller.go:500+) - Manages I/O for running process + +### 2. Connection Controllers + +#### SSH Connections (`pkg/remote/conncontroller/`) + +**Primary File:** [`conncontroller.go`](../pkg/remote/conncontroller/conncontroller.go) + +**Architecture:** +- **Global Registry**: `clientControllerMap` maintains all SSH connections +- **Connection Lifecycle**: + ``` + init → connecting → connected → (running) → disconnected/error + ``` +- **Thread Safety**: Each connection has its own lock (`SSHConn.Lock`) + +**SSHConn Structure:** +```go +type SSHConn struct { + Lock *sync.Mutex + Status string // Connection state + WshEnabled *atomic.Bool // WSH availability flag + Opts *remote.SSHOpts // Connection parameters + Client *ssh.Client // Underlying SSH client + DomainSockName string // Unix socket for RPC + DomainSockListener net.Listener // Socket listener + ConnController *ssh.Session // Runs "wsh connserver" + Error string // Connection error + WshError string // WSH-specific error + WshVersion string // Installed WSH version + // ... +} +``` + +**Key Responsibilities:** +1. **SSH Client Management**: + - Establishes SSH connections using [`golang.org/x/crypto/ssh`](https://pkg.go.dev/golang.org/x/crypto/ssh) + - Handles authentication (pubkey, password, keyboard-interactive) + - Supports ProxyJump for multi-hop connections + +2. **Domain Socket Setup** ([`OpenDomainSocketListener()`](../pkg/remote/conncontroller/conncontroller.go:201)): + - Creates Unix domain socket on remote host (`/tmp/waveterm-*.sock`) + - Enables bidirectional RPC communication + - Socket used by both connserver and shell processes + +3. **WSH (Wave Shell Extensions) Management**: + - **Version Check** ([`StartConnServer()`](../pkg/remote/conncontroller/conncontroller.go:277)): Runs `wsh version` to check installation + - **Installation** ([`InstallWsh()`](../pkg/remote/conncontroller/conncontroller.go:478)): Copies appropriate WSH binary to remote + - **Update** ([`UpdateWsh()`](../pkg/remote/conncontroller/conncontroller.go:417)): Updates existing WSH installation + - **User Prompts** ([`getPermissionToInstallWsh()`](../pkg/remote/conncontroller/conncontroller.go:434)): Asks user for install permission + +4. **Connection Server** (`wsh connserver`): + - Long-running process on remote host + - Provides RPC services for file operations, command execution, etc. + - Communicates via domain socket + - Template: [`ConnServerCmdTemplate`](../pkg/remote/conncontroller/conncontroller.go:74) + +**Connection Flow:** +``` +1. GetConn(opts) - Retrieve or create connection +2. Connect(ctx) - Initiate connection +3. CheckIfNeedsAuth() - Verify authentication needed +4. OpenDomainSocketListener() - Set up RPC channel +5. StartConnServer() - Launch wsh connserver +6. (Install/Update WSH if needed) +7. Status: Connected - Ready for shell processes +``` + +#### SSH Client (`pkg/remote/sshclient.go`) + +**Responsibilities:** +- **Authentication Methods**: + - Public key with optional passphrase ([`createPublicKeyCallback()`](../pkg/remote/sshclient.go:118)) + - Password authentication ([`createPasswordCallbackPrompt()`](../pkg/remote/sshclient.go:227)) + - Keyboard-interactive ([`createInteractiveKbdInteractiveChallenge()`](../pkg/remote/sshclient.go:264)) + - SSH agent support + +- **Known Hosts Verification** ([`createHostKeyCallback()`](../pkg/remote/sshclient.go:429)): + - Reads `~/.ssh/known_hosts` and global known_hosts + - Prompts user for unknown hosts + - Handles key changes/mismatches + +- **ProxyJump Support**: + - Recursive connection through jump hosts + - Max depth: `SshProxyJumpMaxDepth = 10` + +- **User Interaction**: + - Integrates with Wave's [`userinput`](../pkg/userinput/) system + - Non-blocking prompts for passwords, passphrases, host verification + +#### WSL Connections (`pkg/wslconn/`) + +**Primary File:** [`wslconn.go`](../pkg/wslconn/wslconn.go) + +**Architecture:** +- **Similar to SSH**: Parallel structure to `conncontroller` but for WSL +- **Global Registry**: `clientControllerMap` for WSL connections +- **Connection Naming**: `wsl://[distro-name]` (e.g., `wsl://Ubuntu`) + +**WslConn Structure:** +```go +type WslConn struct { + Lock *sync.Mutex + Status string + WshEnabled *atomic.Bool + Name wsl.WslName // Distro name + Client *wsl.Distro // WSL distro interface + DomainSockName string // Uses RemoteFullDomainSocketPath + ConnController *wsl.WslCmd // Runs "wsh connserver" + // ... similar to SSHConn +} +``` + +**Key Differences from SSH:** +- **No Network Socket**: WSL processes run locally, no SSH connection needed +- **Domain Socket Path**: Uses predetermined path ([`wavebase.RemoteFullDomainSocketPath`](../pkg/wavebase/)) +- **Command Execution**: Uses `wsl.exe` command-line tool +- **Simpler Authentication**: No auth needed, user already logged into Windows + +**Connection Flow:** +``` +1. GetWslConn(distroName) - Get/create WSL connection +2. Connect(ctx) - Start connection process +3. OpenDomainSocketListener() - Set domain socket path (no actual listener) +4. StartConnServer() - Launch wsh connserver in WSL +5. (Install/Update WSH if needed) +6. Status: Connected - Ready for shell processes +``` + +### 3. Shell Process Execution (`pkg/shellexec/`) + +**Primary File:** [`shellexec.go`](../pkg/shellexec/shellexec.go) + +**ShellProc Structure:** +```go +type ShellProc struct { + ConnName string // Connection identifier + Cmd ConnInterface // Actual process interface + CloseOnce *sync.Once // Ensures single close + DoneCh chan any // Signals process completion + WaitErr error // Process exit status +} +``` + +**ConnInterface Implementations:** +- **Local**: [`CombinedConnInterface`](../pkg/shellexec/) wraps `os/exec.Cmd` with PTY +- **SSH**: [`RemoteConnInterface`](../pkg/shellexec/) wraps SSH session +- **WSL**: [`WslConnInterface`](../pkg/shellexec/) wraps WSL command + +**Process Startup Functions:** +- [`StartLocalShellProc()`](../pkg/shellexec/) - Local shell processes +- [`StartRemoteShellProc()`](../pkg/shellexec/) - SSH remote shells (with WSH) +- [`StartRemoteShellProcNoWsh()`](../pkg/shellexec/) - SSH remote shells (no WSH) +- [`StartWslShellProc()`](../pkg/shellexec/) - WSL shells (with WSH) +- [`StartWslShellProcNoWsh()`](../pkg/shellexec/) - WSL shells (no WSH) + +**Key Features:** +- **PTY Management**: Pseudo-terminal for interactive shells +- **Graceful Shutdown**: Sends SIGTERM, waits briefly, then SIGKILL +- **Process Wrapping**: Abstracts differences between local/remote/WSL execution + +### 4. Generic Connection Interface (`pkg/genconn/`) + +**Purpose**: Provides abstraction layer for running commands across different connection types + +**Primary File:** [`ssh-impl.go`](../pkg/genconn/ssh-impl.go) + +**Interface Hierarchy:** +```go +ShellClient -> ShellProcessController +``` + +**SSHShellClient:** +- Wraps `*ssh.Client` +- Creates `SSHProcessController` for each command + +**SSHProcessController:** +- Wraps `*ssh.Session` +- Implements stdio piping (stdin, stdout, stderr) +- Handles command lifecycle (Start, Wait, Kill) +- Thread-safe with internal locking + +**Usage Pattern:** +```go +client := genconn.MakeSSHShellClient(sshClient) +proc, _ := client.MakeProcessController(cmdSpec) +stdout, _ := proc.StdoutPipe() +proc.Start() +// Read from stdout... +proc.Wait() +``` + +### 5. Shell Utilities (`pkg/util/shellutil/`) + +**Primary File:** [`shellutil.go`](../pkg/util/shellutil/shellutil.go) + +**Responsibilities:** + +1. **Shell Detection**: + - [`DetectLocalShellPath()`](../pkg/util/shellutil/shellutil.go:87) - Finds user's default shell + - [`GetShellTypeFromShellPath()`](../pkg/util/shellutil/shellutil.go:462) - Identifies shell type (bash, zsh, fish, pwsh) + - [`DetectShellTypeAndVersion()`](../pkg/util/shellutil/shellutil.go:486) - Gets shell version info + +2. **Shell Integration Files**: + - [`InitCustomShellStartupFiles()`](../pkg/util/shellutil/shellutil.go:270) - Creates Wave's shell integration + - Manages startup files for each shell type: + - Bash: `.bashrc` in `shell/bash/` + - Zsh: `.zshrc`, `.zprofile`, etc. in `shell/zsh/` + - Fish: `wave.fish` in `shell/fish/` + - PowerShell: `wavepwsh.ps1` in `shell/pwsh/` + +3. **Environment Management**: + - [`WaveshellLocalEnvVars()`](../pkg/util/shellutil/shellutil.go:218) - Wave-specific environment variables + - [`UpdateCmdEnv()`](../pkg/util/shellutil/shellutil.go:231) - Updates command environment + +4. **WSH Binary Management**: + - [`GetLocalWshBinaryPath()`](../pkg/util/shellutil/shellutil.go:334) - Locates platform-specific WSH binary + - Supports multiple OS/arch combinations + +5. **Git Bash Detection** (Windows): + - [`FindGitBash()`](../pkg/util/shellutil/shellutil.go:156) - Locates Git Bash installation + - Checks multiple common installation paths + +## Connection Types and Workflows + +### Local Connections + +**Connection Name**: `"local"`, `"local:"`, or `""` (empty) + +**Workflow:** +1. Block controller checks connection type via [`IsLocalConnName()`](../pkg/remote/conncontroller/conncontroller.go:80) +2. No connection setup needed +3. Shell process started directly via [`StartLocalShellProc()`](../pkg/shellexec/) +4. Uses `os/exec.Cmd` with PTY +5. WSH integration via environment variables + +**Special Case - Git Bash (Windows):** +- Variant: `"local:gitbash"` +- Requires special shell path detection +- Uses Git Bash binary instead of default shell + +### SSH Connections + +**Connection Name**: `"user@host:port"` (parsed by [`remote.ParseOpts()`](../pkg/remote/)) + +**Full Connection Workflow:** + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ 1. Connection Request (from Block Controller) │ +└─────────────────────────────────────────────────────────────────┘ + ↓ +┌─────────────────────────────────────────────────────────────────┐ +│ 2. GetConn(opts) - Retrieve/Create SSHConn │ +│ - Check global registry (clientControllerMap) │ +│ - Create new SSHConn if needed │ +│ - Status: "init" │ +└─────────────────────────────────────────────────────────────────┘ + ↓ +┌─────────────────────────────────────────────────────────────────┐ +│ 3. conn.Connect(ctx) - Establish SSH Connection │ +│ - Status: "connecting" │ +│ - Read SSH config (~/.ssh/config) │ +│ - Resolve ProxyJump if configured │ +│ - Create SSH client auth methods: │ +│ • Public key (with agent support) │ +│ • Password │ +│ • Keyboard-interactive │ +│ - Establish SSH connection │ +│ - Verify known_hosts │ +└─────────────────────────────────────────────────────────────────┘ + ↓ +┌─────────────────────────────────────────────────────────────────┐ +│ 4. OpenDomainSocketListener(ctx) - Set Up RPC Channel │ +│ - Create random socket path: /tmp/waveterm-[random].sock │ +│ - Use ssh.Client.ListenUnix() for remote forwarding │ +│ - Start RPC listener goroutine │ +│ - Socket available for all subsequent operations │ +└─────────────────────────────────────────────────────────────────┘ + ↓ +┌─────────────────────────────────────────────────────────────────┐ +│ 5. StartConnServer(ctx) - Launch Wave Shell Extensions │ +│ - Run: "wsh version" to check installation │ +│ - If not installed or outdated: │ +│ a. Detect remote platform (OS/arch) │ +│ b. Get user permission (if configured) │ +│ c. InstallWsh() - Copy binary to remote │ +│ d. Retry StartConnServer() │ +│ - Run: "wsh connserver" on remote │ +│ - Pass JWT token for authentication │ +│ - Monitor connserver output │ +│ - Wait for RPC route registration │ +│ - Status: "connected" │ +└─────────────────────────────────────────────────────────────────┘ + ↓ +┌─────────────────────────────────────────────────────────────────┐ +│ 6. Connection Ready - Can Start Shell Processes │ +│ - SSHConn available in registry │ +│ - Domain socket active for RPC │ +│ - WSH connserver running │ +└─────────────────────────────────────────────────────────────────┘ + ↓ +┌─────────────────────────────────────────────────────────────────┐ +│ 7. Start Shell Process (from ShellController) │ +│ - setupAndStartShellProcess() │ +│ - Create swap token (for shell integration) │ +│ - StartRemoteShellProc() or StartRemoteShellProcNoWsh() │ +│ - SSH session created for shell │ +│ - PTY allocated │ +│ - Shell starts with Wave integration │ +└─────────────────────────────────────────────────────────────────┘ +``` + +**WSH (Wave Shell Extensions) Details:** + +**What is WSH?** +- Binary program (`wsh`) that runs on remote hosts +- Provides RPC services for Wave Terminal +- Written in Go, cross-platform +- Versioned to match Wave Terminal version + +**WSH Components:** +1. **wsh version**: Reports installed version +2. **wsh connserver**: Long-running RPC server + - Handles file operations + - Executes commands + - Provides remote state information + - Communicates over domain socket + +**WSH Installation Process:** +1. Check if wsh is installed: Run `wsh version` +2. If not installed: Detect platform with `uname -sm` +3. Get appropriate binary from local cache +4. Copy to remote: `~/.waveterm/bin/wsh` +5. Set executable permissions +6. Restart connection process + +**With vs Without WSH:** +- **With WSH**: Full RPC support, better integration, file sync +- **Without WSH**: Basic shell only, limited features +- Fallback to no-WSH mode on installation failure + +### WSL Connections + +**Connection Name**: `"wsl://[distro]"` (e.g., `"wsl://Ubuntu"`) + +**Workflow:** +``` +1. GetWslConn(distroName) - Get/create WslConn +2. conn.Connect(ctx) - Start connection +3. OpenDomainSocketListener() - Set socket path (no actual listener) +4. StartConnServer() - Launch "wsh connserver" via wsl.exe +5. Install/update WSH if needed (similar to SSH) +6. Status: "connected" +7. StartWslShellProc() - Create shell process in WSL +``` + +**Key Differences from SSH:** +- Uses `wsl.exe` command-line tool +- No network connection overhead +- Predetermined domain socket path +- Simpler authentication (inherited from Windows) + +## Token Swap System + +**Purpose**: Pass connection-specific environment variables to shell processes + +**Implementation:** [`shellutil.TokenSwapEntry`](../pkg/util/shellutil/) + +**Flow:** +1. ShellController creates swap token before starting process +2. Token contains: + - Socket name for RPC + - JWT token for authentication + - RPC context (TabId, BlockId, Conn) + - Custom environment variables +3. Token stored in global swap map +4. Shell process receives token ID via environment +5. Shell integration scripts swap token for actual values +6. Token removed from map after use + +**Purpose:** +- Avoid exposing JWT tokens in process listings +- Enable shell integration without hardcoded values +- Support multiple shells on same connection + +## Error Handling and Recovery + +### Connection Failures + +**SSH Connection Errors:** +- Authentication failure → Prompt user (password, passphrase) +- Host key mismatch → Prompt for verification +- Network timeout → Status: "error", display error message +- ProxyJump failure → Error shows which jump host failed + +**Recovery Mechanisms:** +- [`conn.Reconnect(ctx)`](../pkg/remote/conncontroller/) - Close and re-establish connection +- [`conn.WaitForConnect(ctx)`](../pkg/remote/conncontroller/) - Block until connected +- Automatic fallback to no-WSH mode on installation failure + +### Process Failures + +**Shell Process Errors:** +- Process crash → WaitErr contains exit code +- PTY failure → Captured in error message +- I/O errors → Logged and surfaced to user + +**Cleanup:** +- [`ShellProc.Close()`](../pkg/shellexec/shellexec.go:56) - Graceful then forceful kill +- [`SSHConn.close_nolock()`](../pkg/remote/conncontroller/conncontroller.go:167) - Cleanup all resources +- [`deleteController()`](../pkg/blockcontroller/blockcontroller.go:101) - Remove from registry + +## Configuration Integration + +### Connection Configuration + +**Source:** [`pkg/wconfig/`](../pkg/wconfig/) + +**Per-Connection Settings:** +- `conn:wshenabled` - Enable/disable WSH +- `conn:wshpath` - Custom WSH binary path +- `conn:shellpath` - Custom shell path + +**Global Settings:** +- `conn:askbeforewshinstall` - Prompt before WSH installation +- Stored in `~/.waveterm/config/settings.json` +- Per-connection overrides in `~/.waveterm/config/connections.json` + +### SSH Configuration + +**Source:** `~/.ssh/config` + +**Supported Directives:** +- `Host` - Connection matching +- `HostName` - Target hostname +- `Port` - SSH port +- `User` - Username +- `IdentityFile` - Private key paths +- `ProxyJump` - Jump host specification +- `UserKnownHostsFile` - Known hosts file +- `GlobalKnownHostsFile` - System known hosts +- `AddKeysToAgent` - Add keys to SSH agent + +**Library:** [`github.com/kevinburke/ssh_config`](https://github.com/kevinburke/ssh_config) + +## Thread Safety + +### Synchronization Patterns + +**SSHConn/WslConn:** +```go +conn.Lock.Lock() +defer conn.Lock.Unlock() +// ... modify connection state +``` + +**Atomic Flags:** +```go +conn.WshEnabled.Load() // Read WSH enabled status +conn.WshEnabled.Store(v) // Update atomically +``` + +**Controller Registry:** +```go +registryLock.RLock() // Read lock for lookups +registryLock.Lock() // Write lock for modifications +``` + +**ShellProc Completion:** +```go +sp.CloseOnce.Do(func() { // Ensure single execution + sp.WaitErr = waitErr + close(sp.DoneCh) // Signal completion +}) +``` + +## Event System Integration + +### Connection Events + +**Published via:** [`pkg/wps/`](../pkg/wps/) (Wave Publish/Subscribe) + +**Event Types:** +- `Event_ConnChange` - Connection status changed +- `Event_ControllerStatus` - Block controller status update +- `Event_BlockFile` - Block file operation (terminal output) + +**Example:** +```go +wps.Broker.Publish(wps.WaveEvent{ + Event: wps.Event_ConnChange, + Scopes: []string{fmt.Sprintf("connection:%s", connName)}, + Data: connStatus, +}) +``` + +**Frontend Integration:** +- Events received via WebSocket +- Connection status updates UI +- Real-time terminal output streaming + +## Summary of Responsibilities + +| Component | Responsibilities | +|-----------|-----------------| +| **blockcontroller/** | Block lifecycle, controller registry, connection coordination | +| **shellcontroller** | Shell process management, ConnUnion abstraction, I/O handling | +| **conncontroller/** | SSH connection lifecycle, WSH management, domain socket setup | +| **wslconn/** | WSL connection lifecycle, parallel to SSH but for WSL | +| **sshclient.go** | Low-level SSH: auth, known_hosts, ProxyJump | +| **shellexec/** | Process execution abstraction, PTY management | +| **genconn/** | Generic command execution interface | +| **shellutil/** | Shell detection, integration files, environment setup | + +## Key Design Principles + +1. **Layered Architecture**: Clear separation between block management, connection management, and process execution + +2. **Connection Abstraction**: ConnUnion pattern allows uniform handling of Local/SSH/WSL + +3. **WSH Optional**: System works with and without Wave Shell Extensions, degrading gracefully + +4. **Thread Safety**: Defensive locking, atomic flags, singleton patterns prevent race conditions + +5. **Error Recovery**: Multiple retry mechanisms, fallback modes, user prompts for resolution + +6. **Configuration Hierarchy**: Global → Connection-Specific → Runtime overrides + +7. **Event-Driven Updates**: Real-time status updates via pub/sub system + +8. **User Interaction**: Non-blocking prompts for passwords, confirmations, installations + +This architecture provides a robust foundation for Wave Terminal's multi-environment shell capabilities, with clear extension points for adding new connection types or capabilities. \ No newline at end of file diff --git a/aiprompts/fe-conn-arch.md b/aiprompts/fe-conn-arch.md new file mode 100644 index 000000000..eafb46cea --- /dev/null +++ b/aiprompts/fe-conn-arch.md @@ -0,0 +1,1007 @@ +# Wave Terminal Frontend Connection Architecture + +## Overview + +The frontend connection architecture provides a reactive interface for managing and interacting with connections (local, SSH, WSL, S3). It follows a unidirectional data flow pattern where the backend manages connection state, the frontend observes this state through Jotai atoms, and user interactions trigger backend operations via RPC commands. + +## Architecture Pattern + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ User Interface │ +│ - ConnectionButton (displays status) │ +│ - ChangeConnectionBlockModal (connection picker) │ +│ - ConnStatusOverlay (error states) │ +└─────────────────────────────────────────────────────────────────┘ + ↕ +┌─────────────────────────────────────────────────────────────────┐ +│ Jotai Reactive State │ +│ - ConnStatusMapAtom (connection statuses) │ +│ - View Model Atoms (derived connection state) │ +│ - Block Metadata (connection selection) │ +└─────────────────────────────────────────────────────────────────┘ + ↕ +┌─────────────────────────────────────────────────────────────────┐ +│ RPC Commands │ +│ - ConnListCommand (list connections) │ +│ - ConnEnsureCommand (ensure connected) │ +│ - ConnConnectCommand/ConnDisconnectCommand │ +│ - SetMetaCommand (change block connection) │ +│ - ControllerInputCommand (send data to shell) │ +└─────────────────────────────────────────────────────────────────┘ + ↕ +┌─────────────────────────────────────────────────────────────────┐ +│ Backend (see conn-arch.md) │ +│ - Connection Controllers (SSHConn, WslConn) │ +│ - Block Controllers (ShellController) │ +│ - Shell Process Execution │ +└─────────────────────────────────────────────────────────────────┘ +``` + +## Key Components + +### 1. Connection State Management ([`frontend/app/store/global.ts`](../frontend/app/store/global.ts)) + +**ConnStatusMapAtom** +```typescript +const ConnStatusMapAtom = atom(new Map>()) +``` + +- Global registry of connection status atoms +- One atom per connection (keyed by connection name) +- Backend updates status via wave events +- Frontend components subscribe to individual connection atoms + +**getConnStatusAtom()** +```typescript +function getConnStatusAtom(connName: string): PrimitiveAtom +``` + +- Retrieves or creates status atom for a connection +- Returns cached atom if exists +- Creates new atom initialized to default if needed +- Used by view models to track their connection + +**ConnStatus Structure** +```typescript +interface ConnStatus { + status: "init" | "connecting" | "connected" | "disconnected" | "error" + connection: string // Connection name + connected: boolean // Is currently connected + activeconnnum: number // Color assignment number (1-8) + wshenabled: boolean // WSH available on this connection + error?: string // Error message if status is "error" + wsherror?: string // WSH-specific error +} +``` + +**allConnStatusAtom** +```typescript +const allConnStatusAtom = atom((get) => { + const connStatusMap = get(ConnStatusMapAtom) + const connStatuses = Array.from(connStatusMap.values()).map((atom) => get(atom)) + return connStatuses +}) +``` + +- Provides array of all connection statuses +- Used by connection modal to display all available connections +- Automatically updates when any connection status changes + +### 2. Connection Button UI ([`frontend/app/block/blockutil.tsx`](../frontend/app/block/blockutil.tsx)) + +**ConnectionButton Component** + +```typescript +export const ConnectionButton = React.memo( + React.forwardRef( + ({ connection, changeConnModalAtom }, ref) => { + const connStatusAtom = getConnStatusAtom(connection) + const connStatus = jotai.useAtomValue(connStatusAtom) + // ... renders connection status with colored icon + } + ) +) +``` + +**Responsibilities:** +- Displays connection name and status icon +- Color-codes connections (8 colors, cycling) +- Shows visual states: + - **Local**: Laptop icon (grey) + - **Connecting**: Animated dots (yellow/warning) + - **Connected**: Arrow icon (colored by activeconnnum) + - **Error**: Slashed arrow icon (red) + - **Disconnected**: Slashed arrow icon (grey) +- Opens connection modal on click + +**Color Assignment:** +```typescript +function computeConnColorNum(connStatus: ConnStatus): number { + const connColorNum = (connStatus?.activeconnnum ?? 1) % NumActiveConnColors + return connColorNum == 0 ? NumActiveConnColors : connColorNum +} +``` + +- Backend assigns `activeconnnum` sequentially +- Frontend cycles through 8 CSS color variables +- `var(--conn-icon-color-1)` through `var(--conn-icon-color-8)` + +### 3. Connection Selection Modal ([`frontend/app/modals/conntypeahead.tsx`](../frontend/app/modals/conntypeahead.tsx)) + +**ChangeConnectionBlockModal Component** + +**Data Fetching:** +```typescript +useEffect(() => { + if (!changeConnModalOpen) return + + // Fetch available connections + RpcApi.ConnListCommand(TabRpcClient, { timeout: 2000 }) + .then(setConnList) + + RpcApi.WslListCommand(TabRpcClient, { timeout: 2000 }) + .then(setWslList) + + RpcApi.ConnListAWSCommand(TabRpcClient, { timeout: 2000 }) + .then(setS3List) +}, [changeConnModalOpen]) +``` + +**Connection Change Handler:** +```typescript +const changeConnection = async (connName: string) => { + // Update block metadata with new connection + await RpcApi.SetMetaCommand(TabRpcClient, { + oref: WOS.makeORef("block", blockId), + meta: { + connection: connName, + file: newFile, // Reset file path for new connection + "cmd:cwd": null // Clear working directory + } + }) + + // Ensure connection is established + await RpcApi.ConnEnsureCommand(TabRpcClient, { + connname: connName, + logblockid: blockId + }, { timeout: 60000 }) +} +``` + +**Suggestion Categories:** +1. **Local Connections** + - Local machine (`""` or `"local:"`) + - Git Bash (Windows only: `"local:gitbash"`) + - WSL distros (`"wsl://Ubuntu"`, etc.) + +2. **Remote Connections** (SSH) + - User-configured SSH connections + - Format: `"user@host"` or `"user@host:port"` + - Filtered by `display:hidden` config + +3. **S3 Connections** (optional) + - AWS S3 profiles + - Format: `"aws:profile-name"` + +4. **Actions** + - Reconnect (if disconnected/error) + - Disconnect (if connected) + - Edit Connections (opens config editor) + - New Connection (creates new SSH config) + +**Filtering Logic:** +```typescript +function filterConnections( + connList: Array, + connSelected: string, + fullConfig: FullConfigType, + filterOutNowsh: boolean +): Array { + const connectionsConfig = fullConfig.connections + return connList.filter((conn) => { + const hidden = connectionsConfig?.[conn]?.["display:hidden"] ?? false + const wshEnabled = connectionsConfig?.[conn]?.["conn:wshenabled"] ?? true + return conn.includes(connSelected) && + !hidden && + (wshEnabled || !filterOutNowsh) + }) +} +``` + +### 4. Connection Status Overlay ([`frontend/app/block/blockframe.tsx`](../frontend/app/block/blockframe.tsx)) + +**ConnStatusOverlay Component** + +Displays over block content when: +- Connection is disconnected or in error state +- WSH installation/update errors occur +- Not in layout mode (Ctrl+Shift held) +- Connection modal is not open + +**Features:** +- Shows connection status text +- Displays error messages (scrollable) +- Reconnect button (for disconnected/error) +- "Always disable wsh" button (for WSH errors) +- Adaptive layout based on width + +**Handlers:** +```typescript +// Reconnect to failed connection +const handleTryReconnect = () => { + RpcApi.ConnConnectCommand(TabRpcClient, { + host: connName, + logblockid: nodeModel.blockId + }, { timeout: 60000 }) +} + +// Disable WSH for this connection +const handleDisableWsh = async () => { + await RpcApi.SetConnectionsConfigCommand(TabRpcClient, { + host: connName, + metamaptype: { "conn:wshenabled": false } + }) +} +``` + +### 5. View Model Integration + +View models integrate connection state into their reactive data flow: + +#### Terminal View Model ([`frontend/app/view/term/term-model.ts`](../frontend/app/view/term/term-model.ts)) + +```typescript +class TermViewModel implements ViewModel { + // Connection management flag + manageConnection = atom((get) => { + const termMode = get(this.termMode) + if (termMode == "vdom") return false // VDOM mode doesn't show conn button + + const isCmd = get(this.isCmdController) + if (isCmd) return false // Cmd controller doesn't manage connections + + return true // Standard terminals show connection button + }) + + // Connection status for this block + connStatus = atom((get) => { + const blockData = get(this.blockAtom) + const connName = blockData?.meta?.connection + const connAtom = getConnStatusAtom(connName) + return get(connAtom) + }) + + // Filter connections without WSH + filterOutNowsh = atom(false) +} +``` + +**End Icon Button Logic:** +```typescript +endIconButtons = atom((get) => { + const connStatus = get(this.connStatus) + const shellProcStatus = get(this.shellProcStatus) + + // Only show restart button if connected + if (connStatus?.status != "connected") { + return [] + } + + // Show appropriate icon based on shell state + if (shellProcStatus == "init") { + return [{ icon: "play", title: "Click to Start Shell" }] + } else if (shellProcStatus == "running") { + return [{ icon: "refresh", title: "Shell Running. Click to Restart" }] + } else if (shellProcStatus == "done") { + return [{ icon: "refresh", title: "Shell Exited. Click to Restart" }] + } +}) +``` + +#### Preview View Model ([`frontend/app/view/preview/preview-model.tsx`](../frontend/app/view/preview/preview-model.tsx)) + +```typescript +class PreviewModel implements ViewModel { + // Always manages connection + manageConnection = atom(true) + + // Connection status + connStatus = atom((get) => { + const blockData = get(this.blockAtom) + const connName = blockData?.meta?.connection + const connAtom = getConnStatusAtom(connName) + return get(connAtom) + }) + + // Filter out connections without WSH (file ops require WSH) + filterOutNowsh = atom(true) + + // Ensure connection before operations + connection = atom>(async (get) => { + const connName = get(this.blockAtom)?.meta?.connection + try { + await RpcApi.ConnEnsureCommand(TabRpcClient, { + connname: connName + }, { timeout: 60000 }) + globalStore.set(this.connectionError, "") + } catch (e) { + globalStore.set(this.connectionError, e as string) + } + return connName + }) +} +``` + +**File Operations Over Connection:** +```typescript +// Reads file from remote/local connection +statFile = atom>(async (get) => { + const fileName = get(this.metaFilePath) + const path = await this.formatRemoteUri(fileName, get) + + return await RpcApi.FileInfoCommand(TabRpcClient, { + info: { path } + }) +}) + +fullFile = atom>(async (get) => { + const fileName = get(this.metaFilePath) + const path = await this.formatRemoteUri(fileName, get) + + return await RpcApi.FileReadCommand(TabRpcClient, { + info: { path } + }) +}) +``` + +### 6. Block Controller Integration + +**View models do NOT directly manage shell processes.** They interact with block controllers via RPC: + +**Starting a Shell:** +```typescript +// User clicks restart button in terminal +forceRestartController() { + // Backend handles connection verification and process startup + RpcApi.ControllerRestartCommand(TabRpcClient, { + blockid: this.blockId, + force: true + }) +} +``` + +**Sending Input to Shell:** +```typescript +sendDataToController(data: string) { + const b64data = stringToBase64(data) + RpcApi.ControllerInputCommand(TabRpcClient, { + blockid: this.blockId, + inputdata64: b64data + }) +} +``` + +**Backend Block Controller Flow:** +1. Frontend calls `ControllerRestartCommand` +2. Backend `ShellController.Run()` starts +3. `CheckConnStatus()` verifies connection is ready +4. If not connected, triggers connection attempt +5. Once connected, `setupAndStartShellProcess()` +6. `getConnUnion()` retrieves appropriate connection (Local/SSH/WSL) +7. `StartLocalShellProc()`, `StartRemoteShellProc()`, or `StartWslShellProc()` +8. Process I/O managed by `manageRunningShellProcess()` + +## Connection Configuration + +### Hierarchical Configuration System + +Wave uses a three-level config hierarchy for connections: + +1. **Global Settings** (`settings`) +2. **Connection-Level Config** (`connections[connName]`) +3. **Block-Level Overrides** (`block.meta`) + +**Override Resolution:** +```typescript +function getOverrideConfigAtom(blockId: string, key: T): Atom { + return atom((get) => { + // 1. Check block metadata + const metaKeyVal = get(getBlockMetaKeyAtom(blockId, key)) + if (metaKeyVal != null) return metaKeyVal + + // 2. Check connection config + const connName = get(getBlockMetaKeyAtom(blockId, "connection")) + const connConfigKeyVal = get(getConnConfigKeyAtom(connName, key)) + if (connConfigKeyVal != null) return connConfigKeyVal + + // 3. Fall back to global settings + const settingsVal = get(getSettingsKeyAtom(key)) + return settingsVal ?? null + }) +} +``` + +### Common Connection Settings + +**Connection Keywords** (apply to specific connections): +- `conn:wshenabled` - Enable/disable WSH for this connection +- `conn:wshpath` - Custom WSH binary path +- `display:hidden` - Hide connection from selector +- `display:order` - Sort order in connection list +- `term:fontsize` - Font size for terminals on this connection +- `term:theme` - Color theme for terminals on this connection + +**Example Usage in View Models:** +```typescript +// Font size with connection override +fontSizeAtom = atom((get) => { + const blockData = get(this.blockAtom) + const connName = blockData?.meta?.connection + const fullConfig = get(atoms.fullConfigAtom) + + // Check: block meta > connection config > global settings + const fontSize = blockData?.meta?.["term:fontsize"] ?? + fullConfig?.connections?.[connName]?.["term:fontsize"] ?? + get(getSettingsKeyAtom("term:fontsize")) ?? + 12 + + return boundNumber(fontSize, 4, 64) +}) +``` + +## RPC Interface + +### Connection Management Commands + +**ConnListCommand** +```typescript +ConnListCommand(client: RpcClient): Promise +``` +- Returns list of configured SSH connection names +- Used by connection modal to populate remote connections +- Filters by `display:hidden` config on frontend + +**WslListCommand** +```typescript +WslListCommand(client: RpcClient): Promise +``` +- Returns list of installed WSL distribution names +- Windows only (silently fails on other platforms) +- Connection names formatted as `wsl://[distro]` + +**ConnListAWSCommand** +```typescript +ConnListAWSCommand(client: RpcClient): Promise +``` +- Returns list of AWS profile names from config +- Used for S3 preview connections +- Connection names formatted as `aws:[profile]` + +**ConnEnsureCommand** +```typescript +ConnEnsureCommand( + client: RpcClient, + data: { connname: string, logblockid?: string } +): Promise +``` +- Ensures connection is in "connected" state +- Triggers connection if not already connected +- Waits for connection to complete or timeout +- Used before file operations and by view models + +**ConnConnectCommand** +```typescript +ConnConnectCommand( + client: RpcClient, + data: { host: string, logblockid?: string } +): Promise +``` +- Explicitly connects to specified connection +- Used by "Reconnect" action in overlay +- Returns when connection succeeds or fails + +**ConnDisconnectCommand** +```typescript +ConnDisconnectCommand( + client: RpcClient, + connName: string +): Promise +``` +- Disconnects active connection +- Used by "Disconnect" action in connection modal +- Closes all shells/processes on that connection + +**SetMetaCommand** +```typescript +SetMetaCommand( + client: RpcClient, + data: { + oref: string, // WaveObject reference + meta: MetaType // Metadata updates + } +): Promise +``` +- Updates block metadata (including connection) +- Used when changing block's connection +- Triggers backend to switch connection context + +**SetConnectionsConfigCommand** +```typescript +SetConnectionsConfigCommand( + client: RpcClient, + data: { + host: string, // Connection name + metamaptype: any // Config updates + } +): Promise +``` +- Updates connection-level configuration +- Used to disable WSH (`conn:wshenabled: false`) +- Persists to config file + +### File Operations (Connection-Aware) + +**FileInfoCommand** +```typescript +FileInfoCommand( + client: RpcClient, + data: { info: { path: string } } +): Promise +``` +- Gets file metadata (size, type, permissions, etc.) +- Path format: `[connName]:[filepath]` (e.g., `user@host:~/file.txt`) +- Uses connection's WSH for remote files + +**FileReadCommand** +```typescript +FileReadCommand( + client: RpcClient, + data: { info: { path: string } } +): Promise +``` +- Reads file content as base64 +- Supports streaming for large files +- Remote files read via connection's WSH + +### Controller Commands (Indirect Connection Usage) + +**ControllerInputCommand** +```typescript +ControllerInputCommand( + client: RpcClient, + data: { blockid: string, inputdata64: string } +): Promise +``` +- Sends input to block's controller (shell) +- Controller uses block's connection for execution +- Base64-encoded to handle binary data + +**ControllerRestartCommand** +```typescript +ControllerRestartCommand( + client: RpcClient, + data: { blockid: string, force?: boolean } +): Promise +``` +- Restarts block's controller +- Backend checks connection status before starting +- If not connected, triggers connection first + +## Event-Driven Updates + +### Wave Event Subscriptions + +**Connection Status Updates:** +```typescript +waveEventSubscribe({ + eventType: "connstatus", + handler: (event) => { + const status: ConnStatus = event.data + updateConnStatusAtom(status.connection, status) + } +}) +``` +- Backend emits connection status changes +- Frontend updates corresponding atom +- All subscribed components re-render automatically + +**Configuration Updates:** +```typescript +waveEventSubscribe({ + eventType: "config", + handler: (event) => { + const fullConfig = event.data.fullconfig + globalStore.set(atoms.fullConfigAtom, fullConfig) + } +}) +``` +- Backend watches config files for changes +- Pushes updates to all connected frontends +- Connection configuration changes take effect immediately + +## Data Flow Patterns + +### Pattern 1: Changing Block Connection + +``` +User Action: Click connection button → select new connection + ↓ + ChangeConnectionBlockModal.changeConnection() + ↓ + RpcApi.SetMetaCommand({ connection: newConn }) + ↓ + Backend updates block metadata → emits waveobj:update + ↓ + Frontend WOS updates blockAtom + ↓ + View model connStatus atom recomputes + ↓ + ConnectionButton re-renders with new connection + ↓ + RpcApi.ConnEnsureCommand() ensures connected + ↓ + Backend triggers connection if needed + ↓ + Backend emits connstatus events as connection progresses + ↓ + Frontend updates ConnStatus atom ("connecting" → "connected") + ↓ + ConnectionButton shows connecting animation → connected state +``` + +### Pattern 2: Shell Process Lifecycle + +``` +User Action: Press Enter in disconnected terminal + ↓ + View model detects shellProcStatus == "init" or "done" + ↓ + forceRestartController() called + ↓ + RpcApi.ControllerRestartCommand() + ↓ + Backend ShellController.Run() starts + ↓ + CheckConnStatus() verifies connection + ↓ + If not connected: trigger connection + ↓ + (Frontend shows ConnStatusOverlay with "connecting") + ↓ + Connection succeeds → WSH available + ↓ + setupAndStartShellProcess() + ↓ + StartRemoteShellProc() with connection's SSH client + ↓ + Backend emits controllerstatus event + ↓ + Frontend updates shellProcStatus atom + ↓ + View model endIconButtons recomputes (restart button) + ↓ + Terminal ready for input +``` + +### Pattern 3: File Preview Over Connection + +``` +User Action: Open preview block with file path + ↓ + PreviewModel initialized with file path + ↓ + connection atom ensures connection + ↓ + RpcApi.ConnEnsureCommand(connName) + ↓ + Backend establishes connection if needed + ↓ + (Frontend shows ConnStatusOverlay if connecting) + ↓ + Connection ready + ↓ + statFile atom triggers FileInfoCommand + ↓ + Backend routes to connection's WSH + ↓ + WSH executes stat on remote file + ↓ + FileInfo returned to frontend + ↓ + PreviewModel determines if text/binary/streaming + ↓ + fullFile atom triggers FileReadCommand + ↓ + Backend streams file via WSH + ↓ + File content displayed in preview +``` + +## Connection Types and Behaviors + +### Local Connection + +**Connection Names:** +- `""` (empty string) +- `"local"` +- `"local:"` +- `"local:gitbash"` (Windows only) + +**Frontend Behavior:** +- No connection modal interaction needed +- ConnectionButton shows laptop icon (grey) +- No ConnStatusOverlay shown (always "connected") +- File paths used directly without connection prefix +- Shell processes spawn locally via `os/exec` + +**View Model Configuration:** +```typescript +connName = "" // or "local" or "local:gitbash" +connStatus = { + status: "connected", + connection: "", + connected: true, + activeconnnum: 0, // No color assignment + wshenabled: true // Local WSH always available +} +``` + +### SSH Connection + +**Connection Names:** +- Format: `"user@host"`, `"user@host:port"`, or config name +- Examples: `"ubuntu@192.168.1.10"`, `"myserver"`, `"deploy@prod:2222"` + +**Frontend Behavior:** +- ConnectionButton shows arrow icon with color +- Color cycles through 8 colors based on `activeconnnum` +- ConnStatusOverlay shown during connecting/error states +- File paths prefixed with connection: `user@host:~/file.txt` +- Modal allows reconnect/disconnect actions + +**Connection States:** +```typescript +// Connecting +connStatus = { + status: "connecting", + connection: "user@host", + connected: false, + activeconnnum: 3, + wshenabled: false // Not yet determined +} + +// Connected with WSH +connStatus = { + status: "connected", + connection: "user@host", + connected: true, + activeconnnum: 3, + wshenabled: true +} + +// Connected without WSH +connStatus = { + status: "connected", + connection: "user@host", + connected: true, + activeconnnum: 3, + wshenabled: false, + wsherror: "wsh installation failed: permission denied" +} + +// Error +connStatus = { + status: "error", + connection: "user@host", + connected: false, + activeconnnum: 3, + wshenabled: false, + error: "ssh: connection refused" +} +``` + +**WSH Errors:** +- Shown in ConnStatusOverlay +- "always disable wsh" button sets `conn:wshenabled: false` +- Terminal still works without WSH (limited features) +- Preview requires WSH (shows error if unavailable) + +### WSL Connection + +**Connection Names:** +- Format: `"wsl://[distro]"` +- Examples: `"wsl://Ubuntu"`, `"wsl://Debian"`, `"wsl://Ubuntu-20.04"` + +**Frontend Behavior:** +- Similar to SSH (colored arrow icon) +- Listed under "Local" section in modal +- No authentication prompts +- File paths: `wsl://Ubuntu:~/file.txt` + +**Backend Differences:** +- Uses `wsl.exe` instead of SSH +- No network overhead +- Predetermined domain socket path +- Simpler error handling + +### S3 Connection (Preview Only) + +**Connection Names:** +- Format: `"aws:[profile]"` +- Examples: `"aws:default"`, `"aws:production"` + +**Frontend Behavior:** +- Database icon (accent color) +- Only available in Preview view +- No shell/terminal support +- File paths: `aws:profile:/bucket/key` + +**View Model Settings:** +```typescript +// Terminal: S3 not shown +showS3 = atom(false) + +// Preview: S3 shown +showS3 = atom(true) +``` + +## Error Handling + +### Connection Errors + +**Authentication Failures:** +- Backend prompts for credentials via `userinput` events +- Frontend shows UserInputModal +- User enters password/passphrase +- Connection retries automatically + +**Network Errors:** +- ConnStatus.status becomes "error" +- ConnStatus.error contains message +- ConnStatusOverlay displays error +- "Reconnect" button triggers `ConnConnectCommand` + +**WSH Installation Errors:** +- ConnStatus.wsherror contains message +- ConnStatusOverlay shows separate WSH error section +- Options: + - Dismiss error (temporary) + - "always disable wsh" (permanent config change) + +### View Model Error Handling + +**Terminal View:** +```typescript +// Shell won't start if connection failed +endIconButtons = atom((get) => { + const connStatus = get(this.connStatus) + if (connStatus?.status != "connected") { + return [] // Hide restart button + } + // ... show restart button +}) + +// ConnStatusOverlay blocks terminal interaction +``` + +**Preview View:** +```typescript +// File operations return errors +errorMsgAtom = atom(null) as PrimitiveAtom + +statFile = atom(async (get) => { + try { + const fileInfo = await RpcApi.FileInfoCommand(...) + return fileInfo + } catch (e) { + globalStore.set(this.errorMsgAtom, { + status: "File Read Failed", + text: `${e}` + }) + throw e + } +}) + +// Error displayed in preview content area +``` + +## Best Practices + +### For View Model Authors + +1. **Use Connection Atoms:** + ```typescript + connStatus = atom((get) => { + const blockData = get(this.blockAtom) + const connName = blockData?.meta?.connection + return get(getConnStatusAtom(connName)) + }) + ``` + +2. **Check Connection Before Operations:** + ```typescript + if (connStatus?.status != "connected") { + return // Don't attempt operation + } + ``` + +3. **Use ConnEnsureCommand for File Ops:** + ```typescript + await RpcApi.ConnEnsureCommand(TabRpcClient, { + connname: connName, + logblockid: blockId // For better logging + }, { timeout: 60000 }) + ``` + +4. **Set manageConnection Appropriately:** + ```typescript + // Show connection button for views that need connections + manageConnection = atom(true) + + // Hide for views that don't use connections + manageConnection = atom(false) + ``` + +5. **Use filterOutNowsh for WSH Requirements:** + ```typescript + // Filter connections without WSH (file ops, etc.) + filterOutNowsh = atom(true) + + // Allow all connections (basic shell) + filterOutNowsh = atom(false) + ``` + +### For RPC Command Usage + +1. **Always Handle Errors:** + ```typescript + try { + await RpcApi.ConnConnectCommand(...) + } catch (e) { + console.error("Connection failed:", e) + // Update UI to show error + } + ``` + +2. **Use Appropriate Timeouts:** + ```typescript + // Connection operations: longer timeout + { timeout: 60000 } // 60 seconds + + // List operations: shorter timeout + { timeout: 2000 } // 2 seconds + ``` + +3. **Batch Related Operations:** + ```typescript + // Good: Single SetMetaCommand with all changes + await RpcApi.SetMetaCommand(TabRpcClient, { + oref: blockRef, + meta: { + connection: newConn, + file: newPath, + "cmd:cwd": null + } + }) + + // Bad: Multiple SetMetaCommand calls + ``` + +## Summary + +The frontend connection architecture is **reactive and declarative**: + +1. **Backend owns connection state** - All connection management happens in Go +2. **Frontend observes state** - Jotai atoms mirror backend state +3. **User actions trigger backend** - RPC commands initiate backend operations +4. **Events flow back to frontend** - Backend pushes updates via wave events +5. **View models isolate concerns** - Each view manages its own connection needs +6. **Block controllers bridge the gap** - Backend controllers use connections for process execution + +This architecture ensures: +- **Consistency** - Single source of truth (backend) +- **Reactivity** - UI updates automatically with state changes +- **Separation** - Frontend doesn't manage connection lifecycle +- **Flexibility** - Views can easily add connection support +- **Robustness** - Errors handled at appropriate layers \ No newline at end of file diff --git a/frontend/app/aipanel/aipanel.tsx b/frontend/app/aipanel/aipanel.tsx index 440ffd4a4..504910a44 100644 --- a/frontend/app/aipanel/aipanel.tsx +++ b/frontend/app/aipanel/aipanel.tsx @@ -6,7 +6,7 @@ import { waveAIHasSelection } from "@/app/aipanel/waveai-focus-utils"; import { ErrorBoundary } from "@/app/element/errorboundary"; import { atoms, getSettingsKeyAtom } from "@/app/store/global"; import { globalStore } from "@/app/store/jotaiStore"; -import { useTabModel } from "@/app/store/tab-model"; +import { maybeUseTabModel } from "@/app/store/tab-model"; import { checkKeyPressed, keydownWrapper } from "@/util/keyutil"; import { isMacOS, isWindows } from "@/util/platformutil"; import { cn } from "@/util/util"; @@ -255,7 +255,7 @@ const AIPanelComponentInner = memo(() => { const isFocused = jotai.useAtomValue(model.isWaveAIFocusedAtom); const telemetryEnabled = jotai.useAtomValue(getSettingsKeyAtom("telemetry:enabled")) ?? false; const isPanelVisible = jotai.useAtomValue(model.getPanelVisibleAtom()); - const tabModel = useTabModel(); + const tabModel = maybeUseTabModel(); const defaultMode = jotai.useAtomValue(getSettingsKeyAtom("waveai:defaultmode")) ?? "waveai@balanced"; const aiModeConfigs = jotai.useAtomValue(model.aiModeConfigs); diff --git a/frontend/app/aipanel/aipanelinput.tsx b/frontend/app/aipanel/aipanelinput.tsx index 2ee65a609..ec52ca0d1 100644 --- a/frontend/app/aipanel/aipanelinput.tsx +++ b/frontend/app/aipanel/aipanelinput.tsx @@ -169,24 +169,35 @@ export const AIPanelInput = memo(({ onSubmit, status, model }: AIPanelInputProps - - + + ) : ( + + - + + + )} diff --git a/frontend/app/aipanel/waveai-model.tsx b/frontend/app/aipanel/waveai-model.tsx index bbe4b095b..532442c6c 100644 --- a/frontend/app/aipanel/waveai-model.tsx +++ b/frontend/app/aipanel/waveai-model.tsx @@ -42,12 +42,12 @@ export interface DroppedFile { export class WaveAIModel { private static instance: WaveAIModel | null = null; - private inputRef: React.RefObject | null = null; - private scrollToBottomCallback: (() => void) | null = null; - private useChatSendMessage: UseChatSendMessageType | null = null; - private useChatSetMessages: UseChatSetMessagesType | null = null; - private useChatStatus: ChatStatus = "ready"; - private useChatStop: (() => void) | null = null; + inputRef: React.RefObject | null = null; + scrollToBottomCallback: (() => void) | null = null; + useChatSendMessage: UseChatSendMessageType | null = null; + useChatSetMessages: UseChatSetMessagesType | null = null; + useChatStatus: ChatStatus = "ready"; + useChatStop: (() => void) | null = null; // Used for injecting Wave-specific message data into DefaultChatTransport's prepareSendMessagesRequest realMessage: AIMessage | null = null; orefContext: ORef; @@ -324,6 +324,29 @@ export class WaveAIModel { } } + async reloadChatFromBackend(chatIdValue: string): Promise { + const chatData = await RpcApi.GetWaveAIChatCommand(TabRpcClient, { chatid: chatIdValue }); + const messages: UIMessage[] = chatData?.messages ?? []; + globalStore.set(this.isChatEmptyAtom, messages.length === 0); + return messages as WaveUIMessage[]; + } + + async stopResponse() { + this.useChatStop?.(); + await new Promise((resolve) => setTimeout(resolve, 500)); + + const chatIdValue = globalStore.get(this.chatId); + if (!chatIdValue) { + return; + } + try { + const messages = await this.reloadChatFromBackend(chatIdValue); + this.useChatSetMessages?.(messages); + } catch (error) { + console.error("Failed to reload chat after stop:", error); + } + } + getAndClearMessage(): AIMessage | null { const msg = this.realMessage; this.realMessage = null; @@ -448,10 +471,7 @@ export class WaveAIModel { } try { - const chatData = await RpcApi.GetWaveAIChatCommand(TabRpcClient, { chatid: chatIdValue }); - const messages: UIMessage[] = chatData?.messages ?? []; - globalStore.set(this.isChatEmptyAtom, messages.length === 0); - return messages as WaveUIMessage[]; // this is safe just different RPC type vs the FE type, but they are compatible + return await this.reloadChatFromBackend(chatIdValue); } catch (error) { console.error("Failed to load chat:", error); this.setError("Failed to load chat. Starting new chat..."); diff --git a/frontend/app/app.tsx b/frontend/app/app.tsx index 45723fe56..46048f3c9 100644 --- a/frontend/app/app.tsx +++ b/frontend/app/app.tsx @@ -3,6 +3,7 @@ import { ClientModel } from "@/app/store/client-model"; import { GlobalModel } from "@/app/store/global-model"; +import { getTabModelByTabId, TabModelContext } from "@/app/store/tab-model"; import { Workspace } from "@/app/workspace/workspace"; import { ContextMenuModel } from "@/store/contextmenu"; import { atoms, createBlock, getSettingsPrefixAtom, globalStore, isDev, removeFlashError } from "@/store/global"; @@ -31,12 +32,15 @@ const dlog = debug("wave:app"); const focusLog = debug("wave:focus"); const App = ({ onFirstRender }: { onFirstRender: () => void }) => { + const tabId = useAtomValue(atoms.staticTabId); useEffect(() => { onFirstRender(); }, []); return ( - + + + ); }; diff --git a/frontend/app/store/tab-model.ts b/frontend/app/store/tab-model.ts index daded66a0..ec5ab94c1 100644 --- a/frontend/app/store/tab-model.ts +++ b/frontend/app/store/tab-model.ts @@ -7,9 +7,9 @@ import { globalStore } from "./jotaiStore"; import * as WOS from "./wos"; const tabModelCache = new Map(); -const activeTabIdAtom = atom(null) as PrimitiveAtom; +export const activeTabIdAtom = atom(null) as PrimitiveAtom; -class TabModel { +export class TabModel { tabId: string; tabAtom: Atom; tabNumBlocksAtom: Atom; @@ -40,7 +40,7 @@ class TabModel { } } -function getTabModelByTabId(tabId: string): TabModel { +export function getTabModelByTabId(tabId: string): TabModel { let model = tabModelCache.get(tabId); if (model == null) { model = new TabModel(tabId); @@ -49,7 +49,7 @@ function getTabModelByTabId(tabId: string): TabModel { return model; } -function getActiveTabModel(): TabModel | null { +export function getActiveTabModel(): TabModel | null { const activeTabId = globalStore.get(activeTabIdAtom); if (activeTabId == null) { return null; @@ -57,9 +57,9 @@ function getActiveTabModel(): TabModel | null { return getTabModelByTabId(activeTabId); } -const TabModelContext = createContext(undefined); +export const TabModelContext = createContext(undefined); -function useTabModel(): TabModel { +export function useTabModel(): TabModel { const model = useContext(TabModelContext); if (model == null) { throw new Error("useTabModel must be used within a TabModelProvider"); @@ -67,4 +67,6 @@ function useTabModel(): TabModel { return model; } -export { activeTabIdAtom, getActiveTabModel, getTabModelByTabId, TabModel, TabModelContext, useTabModel }; +export function maybeUseTabModel(): TabModel { + return useContext(TabModelContext); +} diff --git a/frontend/app/workspace/workspace.tsx b/frontend/app/workspace/workspace.tsx index 9a0ec1431..fb1d78668 100644 --- a/frontend/app/workspace/workspace.tsx +++ b/frontend/app/workspace/workspace.tsx @@ -7,7 +7,6 @@ import { CenteredDiv } from "@/app/element/quickelems"; import { ModalsRenderer } from "@/app/modals/modalsrenderer"; import { TabBar } from "@/app/tab/tabbar"; import { TabContent } from "@/app/tab/tabcontent"; -import { getTabModelByTabId, TabModelContext } from "@/app/store/tab-model"; import { Widgets } from "@/app/workspace/widgets"; import { WorkspaceLayoutModel } from "@/app/workspace/workspace-layout-model"; import { atoms, getApi } from "@/store/global"; @@ -70,7 +69,7 @@ const WorkspaceElem = memo(() => { className="overflow-hidden" >
- + {tabId !== "" && }
@@ -79,9 +78,7 @@ const WorkspaceElem = memo(() => { No Active Tab ) : (
- - - +
)} diff --git a/pkg/aiusechat/chatstore/chatstore.go b/pkg/aiusechat/chatstore/chatstore.go index 7625a0533..4abe26ba6 100644 --- a/pkg/aiusechat/chatstore/chatstore.go +++ b/pkg/aiusechat/chatstore/chatstore.go @@ -5,6 +5,7 @@ package chatstore import ( "fmt" + "slices" "sync" "github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes" @@ -109,3 +110,20 @@ func (cs *ChatStore) PostMessage(chatId string, aiOpts *uctypes.AIOptsType, mess return nil } + +func (cs *ChatStore) RemoveMessage(chatId string, messageId string) bool { + cs.lock.Lock() + defer cs.lock.Unlock() + + chat := cs.chats[chatId] + if chat == nil { + return false + } + + initialLen := len(chat.NativeMessages) + chat.NativeMessages = slices.DeleteFunc(chat.NativeMessages, func(msg uctypes.GenAIMessage) bool { + return msg.GetMessageId() == messageId + }) + + return len(chat.NativeMessages) < initialLen +} diff --git a/pkg/aiusechat/gemini/gemini-backend.go b/pkg/aiusechat/gemini/gemini-backend.go index 645588146..23a331ded 100644 --- a/pkg/aiusechat/gemini/gemini-backend.go +++ b/pkg/aiusechat/gemini/gemini-backend.go @@ -42,45 +42,6 @@ func ensureAltSse(endpoint string) (string, error) { return endpoint, nil } -// UpdateToolUseData updates the tool use data for a specific tool call in the chat -func UpdateToolUseData(chatId string, toolCallId string, toolUseData uctypes.UIMessageDataToolUse) error { - chat := chatstore.DefaultChatStore.Get(chatId) - if chat == nil { - return fmt.Errorf("chat not found: %s", chatId) - } - - for _, genMsg := range chat.NativeMessages { - chatMsg, ok := genMsg.(*GeminiChatMessage) - if !ok { - continue - } - - for i, part := range chatMsg.Parts { - if part.FunctionCall != nil && part.ToolUseData != nil && part.ToolUseData.ToolCallId == toolCallId { - // Update the message with new tool use data - updatedMsg := &GeminiChatMessage{ - MessageId: chatMsg.MessageId, - Role: chatMsg.Role, - Parts: make([]GeminiMessagePart, len(chatMsg.Parts)), - Usage: chatMsg.Usage, - } - copy(updatedMsg.Parts, chatMsg.Parts) - updatedMsg.Parts[i].ToolUseData = &toolUseData - - aiOpts := &uctypes.AIOptsType{ - APIType: chat.APIType, - Model: chat.Model, - APIVersion: chat.APIVersion, - } - - return chatstore.DefaultChatStore.PostMessage(chatId, aiOpts, updatedMsg) - } - } - } - - return fmt.Errorf("tool call with ID %s not found in chat %s", toolCallId, chatId) -} - // appendPartToLastUserMessage appends a text part to the last user message in the contents slice func appendPartToLastUserMessage(contents []GeminiContent, text string) { for i := len(contents) - 1; i >= 0; i-- { @@ -347,6 +308,14 @@ func processGeminiStream( if errors.Is(err, io.EOF) { break } + if sseHandler.Err() != nil { + partialMsg := extractPartialGeminiMessage(msgID, textBuilder.String()) + return &uctypes.WaveStopReason{ + Kind: uctypes.StopKindCanceled, + ErrorType: "client_disconnect", + ErrorText: "client disconnected", + }, partialMsg, nil + } _ = sseHandler.AiMsgError(fmt.Sprintf("stream decode error: %v", err)) return &uctypes.WaveStopReason{ Kind: uctypes.StopKindError, @@ -512,3 +481,19 @@ func processGeminiStream( return stopReason, assistantMsg, nil } + +func extractPartialGeminiMessage(msgID string, text string) *GeminiChatMessage { + if text == "" { + return nil + } + + return &GeminiChatMessage{ + MessageId: msgID, + Role: "model", + Parts: []GeminiMessagePart{ + { + Text: text, + }, + }, + } +} diff --git a/pkg/aiusechat/gemini/gemini-convertmessage.go b/pkg/aiusechat/gemini/gemini-convertmessage.go index dc43da422..02cd809c4 100644 --- a/pkg/aiusechat/gemini/gemini-convertmessage.go +++ b/pkg/aiusechat/gemini/gemini-convertmessage.go @@ -8,10 +8,12 @@ import ( "encoding/json" "fmt" "log" + "slices" "strings" "github.com/google/uuid" "github.com/wavetermdev/waveterm/pkg/aiusechat/aiutil" + "github.com/wavetermdev/waveterm/pkg/aiusechat/chatstore" "github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes" "github.com/wavetermdev/waveterm/pkg/util/utilfn" ) @@ -416,3 +418,91 @@ func GetFunctionCallInputByToolCallId(aiChat uctypes.AIChat, toolCallId string) } return nil } + +// UpdateToolUseData updates the tool use data for a specific tool call in the chat +func UpdateToolUseData(chatId string, toolCallId string, toolUseData uctypes.UIMessageDataToolUse) error { + chat := chatstore.DefaultChatStore.Get(chatId) + if chat == nil { + return fmt.Errorf("chat not found: %s", chatId) + } + + for _, genMsg := range chat.NativeMessages { + chatMsg, ok := genMsg.(*GeminiChatMessage) + if !ok { + continue + } + + for i, part := range chatMsg.Parts { + if part.FunctionCall != nil && part.ToolUseData != nil && part.ToolUseData.ToolCallId == toolCallId { + // Update the message with new tool use data + updatedMsg := &GeminiChatMessage{ + MessageId: chatMsg.MessageId, + Role: chatMsg.Role, + Parts: make([]GeminiMessagePart, len(chatMsg.Parts)), + Usage: chatMsg.Usage, + } + copy(updatedMsg.Parts, chatMsg.Parts) + updatedMsg.Parts[i].ToolUseData = &toolUseData + + aiOpts := &uctypes.AIOptsType{ + APIType: chat.APIType, + Model: chat.Model, + APIVersion: chat.APIVersion, + } + + return chatstore.DefaultChatStore.PostMessage(chatId, aiOpts, updatedMsg) + } + } + } + + return fmt.Errorf("tool call with ID %s not found in chat %s", toolCallId, chatId) +} + +func RemoveToolUseCall(chatId string, toolCallId string) error { + chat := chatstore.DefaultChatStore.Get(chatId) + if chat == nil { + return fmt.Errorf("chat not found: %s", chatId) + } + + for _, genMsg := range chat.NativeMessages { + chatMsg, ok := genMsg.(*GeminiChatMessage) + if !ok { + continue + } + + partIndex := -1 + for i, part := range chatMsg.Parts { + if part.FunctionCall != nil && part.ToolUseData != nil && part.ToolUseData.ToolCallId == toolCallId { + partIndex = i + break + } + } + + if partIndex == -1 { + continue + } + + updatedMsg := &GeminiChatMessage{ + MessageId: chatMsg.MessageId, + Role: chatMsg.Role, + Parts: slices.Delete(slices.Clone(chatMsg.Parts), partIndex, partIndex+1), + Usage: chatMsg.Usage, + } + + if len(updatedMsg.Parts) == 0 { + chatstore.DefaultChatStore.RemoveMessage(chatId, chatMsg.MessageId) + } else { + aiOpts := &uctypes.AIOptsType{ + APIType: chat.APIType, + Model: chat.Model, + APIVersion: chat.APIVersion, + } + if err := chatstore.DefaultChatStore.PostMessage(chatId, aiOpts, updatedMsg); err != nil { + return err + } + } + return nil + } + + return nil +} diff --git a/pkg/aiusechat/openai/openai-backend.go b/pkg/aiusechat/openai/openai-backend.go index bca2ea44f..dc9172341 100644 --- a/pkg/aiusechat/openai/openai-backend.go +++ b/pkg/aiusechat/openai/openai-backend.go @@ -388,12 +388,13 @@ const ( ) type openaiBlockState struct { - kind openaiBlockKind - localID string // For SSE streaming to UI - toolCallID string // For function calls - toolName string // For function calls - summaryCount int // For reasoning: number of summary parts seen - partialJSON []byte // For function calls: accumulated JSON arguments + kind openaiBlockKind + localID string // For SSE streaming to UI + toolCallID string // For function calls + toolName string // For function calls + summaryCount int // For reasoning: number of summary parts seen + partialJSON []byte // For function calls: accumulated JSON arguments + accumulatedText string // For text blocks: accumulated text content } type openaiStreamingState struct { @@ -438,6 +439,27 @@ func UpdateToolUseData(chatId string, callId string, newToolUseData uctypes.UIMe return fmt.Errorf("function call with callId %s not found in chat %s", callId, chatId) } +func RemoveToolUseCall(chatId string, callId string) error { + chat := chatstore.DefaultChatStore.Get(chatId) + if chat == nil { + return fmt.Errorf("chat not found: %s", chatId) + } + + for _, genMsg := range chat.NativeMessages { + chatMsg, ok := genMsg.(*OpenAIChatMessage) + if !ok { + continue + } + + if chatMsg.FunctionCall != nil && chatMsg.FunctionCall.CallId == callId { + chatstore.DefaultChatStore.RemoveMessage(chatId, chatMsg.MessageId) + return nil + } + } + + return nil +} + func RunOpenAIChatStep( ctx context.Context, sse *sse.SSEHandlerCh, @@ -612,16 +634,6 @@ func handleOpenAIStreamingResp(ctx context.Context, sse *sse.SSEHandlerCh, decod // SSE event processing loop for { - // Check for context cancellation - if err := ctx.Err(); err != nil { - _ = sse.AiMsgError("request cancelled") - return &uctypes.WaveStopReason{ - Kind: uctypes.StopKindCanceled, - ErrorType: "cancelled", - ErrorText: "request cancelled", - }, rtnMessages - } - event, err := decoder.Decode() if err != nil { if errors.Is(err, io.EOF) { @@ -633,6 +645,19 @@ func handleOpenAIStreamingResp(ctx context.Context, sse *sse.SSEHandlerCh, decod ErrorText: "stream ended unexpectedly without completion", }, rtnMessages } + // Check if client disconnected + if sse.Err() != nil { + // SSE connection broken (client stopped/disconnected) + partialMessages := extractPartialTextFromState(state) + if partialMessages != nil { + rtnMessages = append(rtnMessages, partialMessages...) + } + return &uctypes.WaveStopReason{ + Kind: uctypes.StopKindCanceled, + ErrorType: "client_disconnect", + ErrorText: "client disconnected", + }, rtnMessages + } // transport error mid-stream _ = sse.AiMsgError(err.Error()) return &uctypes.WaveStopReason{ @@ -643,10 +668,14 @@ func handleOpenAIStreamingResp(ctx context.Context, sse *sse.SSEHandlerCh, decod } if finalStopReason, finalMessages := handleOpenAIEvent(event, sse, state, cont); finalStopReason != nil { - // Either error or response.completed triggered return rtnStopReason = finalStopReason if finalMessages != nil { rtnMessages = finalMessages + } else if finalStopReason.Kind == uctypes.StopKindCanceled { + partialMessages := extractPartialTextFromState(state) + if partialMessages != nil { + rtnMessages = append(rtnMessages, partialMessages...) + } } return finalStopReason, rtnMessages } @@ -655,6 +684,34 @@ func handleOpenAIStreamingResp(ctx context.Context, sse *sse.SSEHandlerCh, decod // unreachable } +// extractPartialTextFromState extracts accumulated text from streaming state when client disconnects +func extractPartialTextFromState(state *openaiStreamingState) []*OpenAIChatMessage { + var textContent []OpenAIMessageContent + + for _, blockState := range state.blockMap { + if blockState.kind == openaiBlockText && blockState.accumulatedText != "" { + textContent = append(textContent, OpenAIMessageContent{ + Type: "output_text", + Text: blockState.accumulatedText, + }) + } + } + + if len(textContent) == 0 { + return nil + } + + assistantMessage := &OpenAIChatMessage{ + MessageId: uuid.New().String(), + Message: &OpenAIMessage{ + Role: "assistant", + Content: textContent, + }, + } + + return []*OpenAIChatMessage{assistantMessage} +} + // handleOpenAIEvent processes one SSE event block. It may emit SSE parts // and/or return a StopReason and final message when the stream is complete. // @@ -667,6 +724,14 @@ func handleOpenAIEvent( state *openaiStreamingState, cont *uctypes.WaveContinueResponse, ) (final *uctypes.WaveStopReason, messages []*OpenAIChatMessage) { + if err := sse.Err(); err != nil { + return &uctypes.WaveStopReason{ + Kind: uctypes.StopKindCanceled, + ErrorType: "client_disconnect", + ErrorText: "client disconnected", + }, nil + } + eventName := event.Event() data := event.Data() @@ -770,6 +835,7 @@ func handleOpenAIEvent( } if st := state.blockMap[ev.ItemId]; st != nil && st.kind == openaiBlockText { + st.accumulatedText += ev.Delta _ = sse.AiMsgTextDelta(st.localID, ev.Delta) } return nil, nil diff --git a/pkg/aiusechat/openaichat/openaichat-backend.go b/pkg/aiusechat/openaichat/openaichat-backend.go index d9aef7961..04df1a65d 100644 --- a/pkg/aiusechat/openaichat/openaichat-backend.go +++ b/pkg/aiusechat/openaichat/openaichat-backend.go @@ -131,6 +131,14 @@ func processChatStream( if errors.Is(err, io.EOF) { break } + if sseHandler.Err() != nil { + partialMsg := extractPartialTextMessage(msgID, textBuilder.String()) + return &uctypes.WaveStopReason{ + Kind: uctypes.StopKindCanceled, + ErrorType: "client_disconnect", + ErrorText: "client disconnected", + }, partialMsg, nil + } _ = sseHandler.AiMsgError(err.Error()) return &uctypes.WaveStopReason{ Kind: uctypes.StopKindError, @@ -255,3 +263,17 @@ func processChatStream( return stopReason, assistantMsg, nil } + +func extractPartialTextMessage(msgID string, text string) *StoredChatMessage { + if text == "" { + return nil + } + + return &StoredChatMessage{ + MessageId: msgID, + Message: ChatRequestMessage{ + Role: "assistant", + Content: text, + }, + } +} diff --git a/pkg/aiusechat/openaichat/openaichat-convertmessage.go b/pkg/aiusechat/openaichat/openaichat-convertmessage.go index fb8d0b8d0..c2a7dcd07 100644 --- a/pkg/aiusechat/openaichat/openaichat-convertmessage.go +++ b/pkg/aiusechat/openaichat/openaichat-convertmessage.go @@ -11,6 +11,7 @@ import ( "fmt" "log" "net/http" + "slices" "strings" "github.com/wavetermdev/waveterm/pkg/aiusechat/aiutil" @@ -358,3 +359,38 @@ func UpdateToolUseData(chatId string, callId string, newToolUseData uctypes.UIMe return fmt.Errorf("tool call with callId %s not found in chat %s", callId, chatId) } + +func RemoveToolUseCall(chatId string, callId string) error { + chat := chatstore.DefaultChatStore.Get(chatId) + if chat == nil { + return fmt.Errorf("chat not found: %s", chatId) + } + + for _, genMsg := range chat.NativeMessages { + chatMsg, ok := genMsg.(*StoredChatMessage) + if !ok { + continue + } + idx := chatMsg.Message.FindToolCallIndex(callId) + if idx == -1 { + continue + } + updatedMsg := chatMsg.Copy() + updatedMsg.Message.ToolCalls = slices.Delete(updatedMsg.Message.ToolCalls, idx, idx+1) + if len(updatedMsg.Message.ToolCalls) == 0 { + chatstore.DefaultChatStore.RemoveMessage(chatId, chatMsg.MessageId) + } else { + aiOpts := &uctypes.AIOptsType{ + APIType: chat.APIType, + Model: chat.Model, + APIVersion: chat.APIVersion, + } + if err := chatstore.DefaultChatStore.PostMessage(chatId, aiOpts, updatedMsg); err != nil { + return err + } + } + return nil + } + + return nil +} diff --git a/pkg/aiusechat/usechat-backend.go b/pkg/aiusechat/usechat-backend.go index 6ae1d9466..cb380a457 100644 --- a/pkg/aiusechat/usechat-backend.go +++ b/pkg/aiusechat/usechat-backend.go @@ -32,6 +32,10 @@ type UseChatBackend interface { // This is used to update the UI state for tool execution (approval status, results, etc.) UpdateToolUseData(chatId string, toolCallId string, toolUseData uctypes.UIMessageDataToolUse) error + // RemoveToolUseCall removes a tool use call from the chat's native messages. + // This is used to clean up incomplete or canceled tool calls when stopping execution. + RemoveToolUseCall(chatId string, toolCallId string) error + // ConvertToolResultsToNativeChatMessage converts tool execution results into native chat messages // that can be sent back to the AI backend. Returns a slice of messages (some backends may // require multiple messages per tool result). @@ -94,6 +98,10 @@ func (b *openaiResponsesBackend) UpdateToolUseData(chatId string, toolCallId str return openai.UpdateToolUseData(chatId, toolCallId, toolUseData) } +func (b *openaiResponsesBackend) RemoveToolUseCall(chatId string, toolCallId string) error { + return openai.RemoveToolUseCall(chatId, toolCallId) +} + func (b *openaiResponsesBackend) ConvertToolResultsToNativeChatMessage(toolResults []uctypes.AIToolResult) ([]uctypes.GenAIMessage, error) { msgs, err := openai.ConvertToolResultsToOpenAIChatMessage(toolResults) if err != nil { @@ -148,6 +156,10 @@ func (b *openaiCompletionsBackend) UpdateToolUseData(chatId string, toolCallId s return openaichat.UpdateToolUseData(chatId, toolCallId, toolUseData) } +func (b *openaiCompletionsBackend) RemoveToolUseCall(chatId string, toolCallId string) error { + return openaichat.RemoveToolUseCall(chatId, toolCallId) +} + func (b *openaiCompletionsBackend) ConvertToolResultsToNativeChatMessage(toolResults []uctypes.AIToolResult) ([]uctypes.GenAIMessage, error) { return openaichat.ConvertToolResultsToNativeChatMessage(toolResults) } @@ -181,6 +193,10 @@ func (b *anthropicBackend) UpdateToolUseData(chatId string, toolCallId string, t return fmt.Errorf("UpdateToolUseData not implemented for anthropic backend") } +func (b *anthropicBackend) RemoveToolUseCall(chatId string, toolCallId string) error { + return fmt.Errorf("RemoveToolUseCall not implemented for anthropic backend") +} + func (b *anthropicBackend) ConvertToolResultsToNativeChatMessage(toolResults []uctypes.AIToolResult) ([]uctypes.GenAIMessage, error) { msg, err := anthropic.ConvertToolResultsToAnthropicChatMessage(toolResults) if err != nil { @@ -221,6 +237,10 @@ func (b *geminiBackend) UpdateToolUseData(chatId string, toolCallId string, tool return gemini.UpdateToolUseData(chatId, toolCallId, toolUseData) } +func (b *geminiBackend) RemoveToolUseCall(chatId string, toolCallId string) error { + return gemini.RemoveToolUseCall(chatId, toolCallId) +} + func (b *geminiBackend) ConvertToolResultsToNativeChatMessage(toolResults []uctypes.AIToolResult) ([]uctypes.GenAIMessage, error) { msg, err := gemini.ConvertToolResultsToGeminiChatMessage(toolResults) if err != nil { diff --git a/pkg/aiusechat/usechat.go b/pkg/aiusechat/usechat.go index 08e675c43..8d8fcf644 100644 --- a/pkg/aiusechat/usechat.go +++ b/pkg/aiusechat/usechat.go @@ -43,8 +43,7 @@ var ( globalRateLimitInfo = &uctypes.RateLimitInfo{Unknown: true} rateLimitLock sync.Mutex - activeToolMap = ds.MakeSyncMap[bool]() // key is toolcallid - activeChats = ds.MakeSyncMap[bool]() // key is chatid + activeChats = ds.MakeSyncMap[bool]() // key is chatid ) func getSystemPrompt(apiType string, model string, isBuilder bool, hasToolsCapability bool, widgetAccess bool) []string { @@ -252,7 +251,7 @@ func processToolCallInternal(backend UseChatBackend, toolCall uctypes.WaveToolCa if toolCall.ToolUseData.Approval == uctypes.ApprovalNeedsApproval { log.Printf(" waiting for approval...\n") - approval, err := WaitForToolApproval(context.Background(), toolCall.ID) + approval, err := WaitForToolApproval(sseHandler.Context(), toolCall.ID) if err != nil || approval == "" { approval = uctypes.ApprovalCanceled } @@ -321,12 +320,7 @@ func processToolCall(backend UseChatBackend, toolCall uctypes.WaveToolCall, chat return result } -func processToolCalls(backend UseChatBackend, stopReason *uctypes.WaveStopReason, chatOpts uctypes.WaveChatOpts, sseHandler *sse.SSEHandlerCh, metrics *uctypes.AIMetrics) { - for _, toolCall := range stopReason.ToolCalls { - activeToolMap.Set(toolCall.ID, true) - defer activeToolMap.Delete(toolCall.ID) - } - +func processAllToolCalls(backend UseChatBackend, stopReason *uctypes.WaveStopReason, chatOpts uctypes.WaveChatOpts, sseHandler *sse.SSEHandlerCh, metrics *uctypes.AIMetrics) { // Create and send all data-tooluse packets at the beginning for i := range stopReason.ToolCalls { toolCall := &stopReason.ToolCalls[i] @@ -351,17 +345,37 @@ func processToolCalls(backend UseChatBackend, stopReason *uctypes.WaveStopReason var toolResults []uctypes.AIToolResult for _, toolCall := range stopReason.ToolCalls { + if sseHandler.Err() != nil { + log.Printf("AI tool processing stopped: %v\n", sseHandler.Err()) + break + } result := processToolCall(backend, toolCall, chatOpts, sseHandler, metrics) toolResults = append(toolResults, result) + } + + // Cleanup: unregister approvals, remove incomplete/canceled tool calls, and filter results + var filteredResults []uctypes.AIToolResult + for i, toolCall := range stopReason.ToolCalls { UnregisterToolApproval(toolCall.ID) + hasResult := i < len(toolResults) + shouldRemove := !hasResult || (toolCall.ToolUseData != nil && toolCall.ToolUseData.Approval == uctypes.ApprovalCanceled) + if shouldRemove { + backend.RemoveToolUseCall(chatOpts.ChatId, toolCall.ID) + } else if hasResult { + filteredResults = append(filteredResults, toolResults[i]) + } } - toolResultMsgs, err := backend.ConvertToolResultsToNativeChatMessage(toolResults) - if err != nil { - log.Printf("Failed to convert tool results to native chat messages: %v", err) - } else { - for _, msg := range toolResultMsgs { - chatstore.DefaultChatStore.PostMessage(chatOpts.ChatId, &chatOpts.Config, msg) + if len(filteredResults) > 0 { + toolResultMsgs, err := backend.ConvertToolResultsToNativeChatMessage(filteredResults) + if err != nil { + log.Printf("Failed to convert tool results to native chat messages: %v", err) + } else { + for _, msg := range toolResultMsgs { + if err := chatstore.DefaultChatStore.PostMessage(chatOpts.ChatId, &chatOpts.Config, msg); err != nil { + log.Printf("Failed to post tool result message: %v", err) + } + } } } } @@ -419,6 +433,9 @@ func RunAIChat(ctx context.Context, sseHandler *sse.SSEHandlerCh, backend UseCha metrics.PremiumReqCount++ } } + if stopReason != nil { + logutil.DevPrintf("stopreason: %s (%s) (%s) (%s)\n", stopReason.Kind, stopReason.ErrorText, stopReason.ErrorType, stopReason.RawReason) + } if len(rtnMessages) > 0 { usage := getUsage(rtnMessages) log.Printf("usage: input=%d output=%d websearch=%d\n", usage.InputTokens, usage.OutputTokens, usage.NativeWebSearchCount) @@ -441,7 +458,9 @@ func RunAIChat(ctx context.Context, sseHandler *sse.SSEHandlerCh, backend UseCha } for _, msg := range rtnMessages { if msg != nil { - chatstore.DefaultChatStore.PostMessage(chatOpts.ChatId, &chatOpts.Config, msg) + if err := chatstore.DefaultChatStore.PostMessage(chatOpts.ChatId, &chatOpts.Config, msg); err != nil { + log.Printf("Failed to post message: %v", err) + } } } firstStep = false @@ -455,7 +474,7 @@ func RunAIChat(ctx context.Context, sseHandler *sse.SSEHandlerCh, backend UseCha } if stopReason != nil && stopReason.Kind == uctypes.StopKindToolUse { metrics.ToolUseCount += len(stopReason.ToolCalls) - processToolCalls(backend, stopReason, chatOpts, sseHandler, metrics) + processAllToolCalls(backend, stopReason, chatOpts, sseHandler, metrics) cont = &uctypes.WaveContinueResponse{ Model: chatOpts.Config.Model, ContinueFromKind: uctypes.StopKindToolUse, diff --git a/pkg/web/sse/ssehandler.go b/pkg/web/sse/ssehandler.go index cdd055fbd..4012d716a 100644 --- a/pkg/web/sse/ssehandler.go +++ b/pkg/web/sse/ssehandler.go @@ -89,6 +89,10 @@ func MakeSSEHandlerCh(w http.ResponseWriter, ctx context.Context) *SSEHandlerCh } } +func (h *SSEHandlerCh) Context() context.Context { + return h.ctx +} + // SetupSSE configures the response headers and starts the writer goroutine func (h *SSEHandlerCh) SetupSSE() error { h.lock.Lock()