diff --git a/go/client.go b/go/client.go index 50e6b4cc..c88a68ac 100644 --- a/go/client.go +++ b/go/client.go @@ -41,6 +41,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" "github.com/github/copilot-sdk/go/internal/embeddedcli" @@ -86,8 +87,10 @@ type Client struct { lifecycleHandlers []SessionLifecycleHandler typedLifecycleHandlers map[SessionLifecycleEventType][]SessionLifecycleHandler lifecycleHandlersMux sync.Mutex - processDone chan struct{} // closed when CLI process exits - processError error // set before processDone is closed + startStopMux sync.RWMutex // protects process and state during start/[force]stop + processDone chan struct{} + processErrorPtr *error + osProcess atomic.Pointer[os.Process] // RPC provides typed server-scoped RPC methods. // This field is nil until the client is connected via Start(). @@ -251,6 +254,9 @@ func parseCliUrl(url string) (string, int) { // } // // Now ready to create sessions func (c *Client) Start(ctx context.Context) error { + c.startStopMux.Lock() + defer c.startStopMux.Unlock() + if c.state == StateConnected { return nil } @@ -260,6 +266,7 @@ func (c *Client) Start(ctx context.Context) error { // Only start CLI server process if not connecting to external server if !c.isExternalServer { if err := c.startCLIServer(ctx); err != nil { + c.process = nil c.state = StateError return err } @@ -267,14 +274,16 @@ func (c *Client) Start(ctx context.Context) error { // Connect to the server if err := c.connectToServer(ctx); err != nil { + killErr := c.killProcess() c.state = StateError - return err + return errors.Join(err, killErr) } // Verify protocol version compatibility if err := c.verifyProtocolVersion(ctx); err != nil { + killErr := c.killProcess() c.state = StateError - return err + return errors.Join(err, killErr) } c.state = StateConnected @@ -316,13 +325,16 @@ func (c *Client) Stop() error { c.sessions = make(map[string]*Session) c.sessionsMux.Unlock() + c.startStopMux.Lock() + defer c.startStopMux.Unlock() + // Kill CLI process FIRST (this closes stdout and unblocks readLoop) - only if we spawned it if c.process != nil && !c.isExternalServer { - if err := c.process.Process.Kill(); err != nil { - errs = append(errs, fmt.Errorf("failed to kill CLI process: %w", err)) + if err := c.killProcess(); err != nil { + errs = append(errs, err) } - c.process = nil } + c.process = nil // Close external TCP connection if exists if c.isExternalServer && c.conn != nil { @@ -375,16 +387,27 @@ func (c *Client) Stop() error { // client.ForceStop() // } func (c *Client) ForceStop() { + // Kill the process without waiting for startStopMux, which Start may hold. + // This unblocks any I/O Start is doing (connect, version check). + if p := c.osProcess.Swap(nil); p != nil { + p.Kill() + } + // Clear sessions immediately without trying to destroy them c.sessionsMux.Lock() c.sessions = make(map[string]*Session) c.sessionsMux.Unlock() + c.startStopMux.Lock() + defer c.startStopMux.Unlock() + // Kill CLI process (only if we spawned it) + // This is a fallback in case the process wasn't killed above (e.g. if Start hadn't set + // osProcess yet), or if the process was restarted and osProcess now points to a new process. if c.process != nil && !c.isExternalServer { - c.process.Process.Kill() // Ignore errors - c.process = nil + _ = c.killProcess() // Ignore errors since we're force stopping } + c.process = nil // Close external TCP connection if exists if c.isExternalServer && c.conn != nil { @@ -886,6 +909,8 @@ func (c *Client) handleLifecycleEvent(event SessionLifecycleEvent) { // }) // } func (c *Client) State() ConnectionState { + c.startStopMux.RLock() + defer c.startStopMux.RUnlock() return c.state } @@ -1096,21 +1121,11 @@ func (c *Client) startCLIServer(ctx context.Context) error { return fmt.Errorf("failed to start CLI server: %w", err) } - // Monitor process exit to signal pending requests - c.processDone = make(chan struct{}) - go func() { - waitErr := c.process.Wait() - if waitErr != nil { - c.processError = fmt.Errorf("CLI process exited: %v", waitErr) - } else { - c.processError = fmt.Errorf("CLI process exited unexpectedly") - } - close(c.processDone) - }() + c.monitorProcess() // Create JSON-RPC client immediately c.client = jsonrpc2.NewClient(stdin, stdout) - c.client.SetProcessDone(c.processDone, &c.processError) + c.client.SetProcessDone(c.processDone, c.processErrorPtr) c.RPC = rpc.NewServerRpc(c.client) c.setupNotificationHandler() c.client.Start() @@ -1127,7 +1142,8 @@ func (c *Client) startCLIServer(ctx context.Context) error { return fmt.Errorf("failed to start CLI server: %w", err) } - // Wait for port announcement + c.monitorProcess() + scanner := bufio.NewScanner(stdout) timeout := time.After(10 * time.Second) portRegex := regexp.MustCompile(`listening on port (\d+)`) @@ -1135,14 +1151,19 @@ func (c *Client) startCLIServer(ctx context.Context) error { for { select { case <-timeout: - return fmt.Errorf("timeout waiting for CLI server to start") + killErr := c.killProcess() + return errors.Join(errors.New("timeout waiting for CLI server to start"), killErr) + case <-c.processDone: + killErr := c.killProcess() + return errors.Join(errors.New("CLI server process exited before reporting port"), killErr) default: if scanner.Scan() { line := scanner.Text() if matches := portRegex.FindStringSubmatch(line); len(matches) > 1 { port, err := strconv.Atoi(matches[1]) if err != nil { - return fmt.Errorf("failed to parse port: %w", err) + killErr := c.killProcess() + return errors.Join(fmt.Errorf("failed to parse port: %w", err), killErr) } c.actualPort = port return nil @@ -1153,6 +1174,39 @@ func (c *Client) startCLIServer(ctx context.Context) error { } } +func (c *Client) killProcess() error { + if p := c.osProcess.Swap(nil); p != nil { + if err := p.Kill(); err != nil { + return fmt.Errorf("failed to kill CLI process: %w", err) + } + } + c.process = nil + return nil +} + +// monitorProcess signals when the CLI process exits and captures any exit error. +// processError is intentionally a local: each process lifecycle gets its own +// error value, so goroutines from previous processes can't overwrite the +// current one. Closing the channel synchronizes with readers, guaranteeing +// they see the final processError value. +func (c *Client) monitorProcess() { + done := make(chan struct{}) + c.processDone = done + proc := c.process + c.osProcess.Store(proc.Process) + var processError error + c.processErrorPtr = &processError + go func() { + waitErr := proc.Wait() + if waitErr != nil { + processError = fmt.Errorf("CLI process exited: %w", waitErr) + } else { + processError = errors.New("CLI process exited unexpectedly") + } + close(done) + }() +} + // connectToServer establishes a connection to the server. func (c *Client) connectToServer(ctx context.Context) error { if c.useStdio { @@ -1184,6 +1238,9 @@ func (c *Client) connectViaTcp(ctx context.Context) error { // Create JSON-RPC client with the connection c.client = jsonrpc2.NewClient(conn, conn) + if c.processDone != nil { + c.client.SetProcessDone(c.processDone, c.processErrorPtr) + } c.RPC = rpc.NewServerRpc(c.client) c.setupNotificationHandler() c.client.Start() diff --git a/go/client_test.go b/go/client_test.go index 2d198f22..752bdc75 100644 --- a/go/client_test.go +++ b/go/client_test.go @@ -6,6 +6,7 @@ import ( "path/filepath" "reflect" "regexp" + "sync" "testing" ) @@ -486,3 +487,44 @@ func TestClient_ResumeSession_RequiresPermissionHandler(t *testing.T) { } }) } + +func TestClient_StartStopRace(t *testing.T) { + cliPath := findCLIPathForTest() + if cliPath == "" { + t.Skip("CLI not found") + } + client := NewClient(&ClientOptions{CLIPath: cliPath}) + defer client.ForceStop() + errChan := make(chan error) + wg := sync.WaitGroup{} + for range 10 { + wg.Add(3) + go func() { + defer wg.Done() + if err := client.Start(t.Context()); err != nil { + select { + case errChan <- err: + default: + } + } + }() + go func() { + defer wg.Done() + if err := client.Stop(); err != nil { + select { + case errChan <- err: + default: + } + } + }() + go func() { + defer wg.Done() + client.ForceStop() + }() + } + wg.Wait() + close(errChan) + if err := <-errChan; err != nil { + t.Fatal(err) + } +} diff --git a/go/internal/jsonrpc2/jsonrpc2.go b/go/internal/jsonrpc2/jsonrpc2.go index 03cf49b3..09505c06 100644 --- a/go/internal/jsonrpc2/jsonrpc2.go +++ b/go/internal/jsonrpc2/jsonrpc2.go @@ -8,6 +8,7 @@ import ( "io" "reflect" "sync" + "sync/atomic" ) // Error represents a JSON-RPC error response @@ -54,7 +55,7 @@ type Client struct { mu sync.Mutex pendingRequests map[string]chan *Response requestHandlers map[string]RequestHandler - running bool + running atomic.Bool stopChan chan struct{} wg sync.WaitGroup processDone chan struct{} // closed when the underlying process exits @@ -97,17 +98,17 @@ func (c *Client) getProcessError() error { // Start begins listening for messages in a background goroutine func (c *Client) Start() { - c.running = true + c.running.Store(true) c.wg.Add(1) go c.readLoop() } // Stop stops the client and cleans up func (c *Client) Stop() { - if !c.running { + if !c.running.Load() { return } - c.running = false + c.running.Store(false) close(c.stopChan) // Close stdout to unblock the readLoop @@ -298,14 +299,14 @@ func (c *Client) readLoop() { reader := bufio.NewReader(c.stdout) - for c.running { + for c.running.Load() { // Read Content-Length header var contentLength int for { line, err := reader.ReadString('\n') if err != nil { // Only log unexpected errors (not EOF or closed pipe during shutdown) - if err != io.EOF && c.running { + if err != io.EOF && c.running.Load() { fmt.Printf("Error reading header: %v\n", err) } return diff --git a/go/test.sh b/go/test.sh old mode 100644 new mode 100755 index c3f33fb0..e1dd8aaa --- a/go/test.sh +++ b/go/test.sh @@ -8,7 +8,7 @@ echo # Check prerequisites if ! command -v go &> /dev/null; then - echo "❌ Go is not installed. Please install Go 1.21 or later." + echo "❌ Go is not installed. Please install Go 1.24 or later." echo " Visit: https://golang.org/dl/" exit 1 fi @@ -43,7 +43,7 @@ cd "$(dirname "$0")" echo "=== Running Go SDK E2E Tests ===" echo -go test -v ./... +go test -v ./... -race echo echo "✅ All tests passed!"