Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 81 additions & 24 deletions go/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/github/copilot-sdk/go/internal/embeddedcli"
Expand Down Expand Up @@ -86,8 +87,10 @@ type Client struct {
lifecycleHandlers []SessionLifecycleHandler
typedLifecycleHandlers map[SessionLifecycleEventType][]SessionLifecycleHandler
lifecycleHandlersMux sync.Mutex
processDone chan struct{} // closed when CLI process exits
processError error // set before processDone is closed
startStopMux sync.RWMutex // protects process and state during start/[force]stop
processDone chan struct{}
processErrorPtr *error
osProcess atomic.Pointer[os.Process]

// RPC provides typed server-scoped RPC methods.
// This field is nil until the client is connected via Start().
Expand Down Expand Up @@ -251,6 +254,9 @@ func parseCliUrl(url string) (string, int) {
// }
// // Now ready to create sessions
func (c *Client) Start(ctx context.Context) error {
c.startStopMux.Lock()
defer c.startStopMux.Unlock()

if c.state == StateConnected {
return nil
}
Expand All @@ -260,21 +266,24 @@ func (c *Client) Start(ctx context.Context) error {
// Only start CLI server process if not connecting to external server
if !c.isExternalServer {
if err := c.startCLIServer(ctx); err != nil {
c.process = nil
c.state = StateError
return err
}
}

// Connect to the server
if err := c.connectToServer(ctx); err != nil {
killErr := c.killProcess()
c.state = StateError
return err
return errors.Join(err, killErr)
}

// Verify protocol version compatibility
if err := c.verifyProtocolVersion(ctx); err != nil {
killErr := c.killProcess()
c.state = StateError
return err
return errors.Join(err, killErr)
}

c.state = StateConnected
Expand Down Expand Up @@ -316,13 +325,16 @@ func (c *Client) Stop() error {
c.sessions = make(map[string]*Session)
c.sessionsMux.Unlock()

c.startStopMux.Lock()
defer c.startStopMux.Unlock()

// Kill CLI process FIRST (this closes stdout and unblocks readLoop) - only if we spawned it
if c.process != nil && !c.isExternalServer {
if err := c.process.Process.Kill(); err != nil {
errs = append(errs, fmt.Errorf("failed to kill CLI process: %w", err))
if err := c.killProcess(); err != nil {
errs = append(errs, err)
}
c.process = nil
}
c.process = nil

// Close external TCP connection if exists
if c.isExternalServer && c.conn != nil {
Expand Down Expand Up @@ -375,16 +387,27 @@ func (c *Client) Stop() error {
// client.ForceStop()
// }
func (c *Client) ForceStop() {
// Kill the process without waiting for startStopMux, which Start may hold.
// This unblocks any I/O Start is doing (connect, version check).
if p := c.osProcess.Swap(nil); p != nil {
p.Kill()
}

// Clear sessions immediately without trying to destroy them
c.sessionsMux.Lock()
c.sessions = make(map[string]*Session)
c.sessionsMux.Unlock()

c.startStopMux.Lock()
defer c.startStopMux.Unlock()

// Kill CLI process (only if we spawned it)
// This is a fallback in case the process wasn't killed above (e.g. if Start hadn't set
// osProcess yet), or if the process was restarted and osProcess now points to a new process.
if c.process != nil && !c.isExternalServer {
c.process.Process.Kill() // Ignore errors
c.process = nil
_ = c.killProcess() // Ignore errors since we're force stopping
}
c.process = nil

// Close external TCP connection if exists
if c.isExternalServer && c.conn != nil {
Expand Down Expand Up @@ -886,6 +909,8 @@ func (c *Client) handleLifecycleEvent(event SessionLifecycleEvent) {
// })
// }
func (c *Client) State() ConnectionState {
c.startStopMux.RLock()
defer c.startStopMux.RUnlock()
return c.state
}

Expand Down Expand Up @@ -1096,21 +1121,11 @@ func (c *Client) startCLIServer(ctx context.Context) error {
return fmt.Errorf("failed to start CLI server: %w", err)
}

// Monitor process exit to signal pending requests
c.processDone = make(chan struct{})
go func() {
waitErr := c.process.Wait()
if waitErr != nil {
c.processError = fmt.Errorf("CLI process exited: %v", waitErr)
} else {
c.processError = fmt.Errorf("CLI process exited unexpectedly")
}
close(c.processDone)
}()
c.monitorProcess()

// Create JSON-RPC client immediately
c.client = jsonrpc2.NewClient(stdin, stdout)
c.client.SetProcessDone(c.processDone, &c.processError)
c.client.SetProcessDone(c.processDone, c.processErrorPtr)
c.RPC = rpc.NewServerRpc(c.client)
c.setupNotificationHandler()
c.client.Start()
Expand All @@ -1127,22 +1142,28 @@ func (c *Client) startCLIServer(ctx context.Context) error {
return fmt.Errorf("failed to start CLI server: %w", err)
}

// Wait for port announcement
c.monitorProcess()

scanner := bufio.NewScanner(stdout)
timeout := time.After(10 * time.Second)
portRegex := regexp.MustCompile(`listening on port (\d+)`)

for {
select {
case <-timeout:
return fmt.Errorf("timeout waiting for CLI server to start")
killErr := c.killProcess()
return errors.Join(errors.New("timeout waiting for CLI server to start"), killErr)
case <-c.processDone:
killErr := c.killProcess()
return errors.Join(errors.New("CLI server process exited before reporting port"), killErr)
default:
if scanner.Scan() {
line := scanner.Text()
if matches := portRegex.FindStringSubmatch(line); len(matches) > 1 {
port, err := strconv.Atoi(matches[1])
if err != nil {
return fmt.Errorf("failed to parse port: %w", err)
killErr := c.killProcess()
return errors.Join(fmt.Errorf("failed to parse port: %w", err), killErr)
}
c.actualPort = port
return nil
Expand All @@ -1153,6 +1174,39 @@ func (c *Client) startCLIServer(ctx context.Context) error {
}
}

func (c *Client) killProcess() error {
if p := c.osProcess.Swap(nil); p != nil {
if err := p.Kill(); err != nil {
return fmt.Errorf("failed to kill CLI process: %w", err)
}
}
c.process = nil
return nil
}

// monitorProcess signals when the CLI process exits and captures any exit error.
// processError is intentionally a local: each process lifecycle gets its own
// error value, so goroutines from previous processes can't overwrite the
// current one. Closing the channel synchronizes with readers, guaranteeing
// they see the final processError value.
func (c *Client) monitorProcess() {
done := make(chan struct{})
c.processDone = done
proc := c.process
c.osProcess.Store(proc.Process)
var processError error
c.processErrorPtr = &processError
go func() {
waitErr := proc.Wait()
if waitErr != nil {
processError = fmt.Errorf("CLI process exited: %w", waitErr)
} else {
processError = errors.New("CLI process exited unexpectedly")
}
close(done)
}()
}

// connectToServer establishes a connection to the server.
func (c *Client) connectToServer(ctx context.Context) error {
if c.useStdio {
Expand Down Expand Up @@ -1184,6 +1238,9 @@ func (c *Client) connectViaTcp(ctx context.Context) error {

// Create JSON-RPC client with the connection
c.client = jsonrpc2.NewClient(conn, conn)
if c.processDone != nil {
c.client.SetProcessDone(c.processDone, c.processErrorPtr)
}
c.RPC = rpc.NewServerRpc(c.client)
c.setupNotificationHandler()
c.client.Start()
Expand Down
42 changes: 42 additions & 0 deletions go/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"path/filepath"
"reflect"
"regexp"
"sync"
"testing"
)

Expand Down Expand Up @@ -486,3 +487,44 @@ func TestClient_ResumeSession_RequiresPermissionHandler(t *testing.T) {
}
})
}

func TestClient_StartStopRace(t *testing.T) {
cliPath := findCLIPathForTest()
if cliPath == "" {
t.Skip("CLI not found")
}
client := NewClient(&ClientOptions{CLIPath: cliPath})
defer client.ForceStop()
errChan := make(chan error)
wg := sync.WaitGroup{}
for range 10 {
wg.Add(3)
go func() {
defer wg.Done()
if err := client.Start(t.Context()); err != nil {
select {
case errChan <- err:
default:
}
}
}()
go func() {
defer wg.Done()
if err := client.Stop(); err != nil {
select {
case errChan <- err:
default:
}
}
}()
go func() {
defer wg.Done()
client.ForceStop()
}()
}
wg.Wait()
close(errChan)
if err := <-errChan; err != nil {
t.Fatal(err)
}
}
13 changes: 7 additions & 6 deletions go/internal/jsonrpc2/jsonrpc2.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"io"
"reflect"
"sync"
"sync/atomic"
)

// Error represents a JSON-RPC error response
Expand Down Expand Up @@ -54,7 +55,7 @@ type Client struct {
mu sync.Mutex
pendingRequests map[string]chan *Response
requestHandlers map[string]RequestHandler
running bool
running atomic.Bool
stopChan chan struct{}
wg sync.WaitGroup
processDone chan struct{} // closed when the underlying process exits
Expand Down Expand Up @@ -97,17 +98,17 @@ func (c *Client) getProcessError() error {

// Start begins listening for messages in a background goroutine
func (c *Client) Start() {
c.running = true
c.running.Store(true)
c.wg.Add(1)
go c.readLoop()
}

// Stop stops the client and cleans up
func (c *Client) Stop() {
if !c.running {
if !c.running.Load() {
return
}
c.running = false
c.running.Store(false)
close(c.stopChan)

// Close stdout to unblock the readLoop
Expand Down Expand Up @@ -298,14 +299,14 @@ func (c *Client) readLoop() {

reader := bufio.NewReader(c.stdout)

for c.running {
for c.running.Load() {
// Read Content-Length header
var contentLength int
for {
line, err := reader.ReadString('\n')
if err != nil {
// Only log unexpected errors (not EOF or closed pipe during shutdown)
if err != io.EOF && c.running {
if err != io.EOF && c.running.Load() {
fmt.Printf("Error reading header: %v\n", err)
}
return
Expand Down
4 changes: 2 additions & 2 deletions go/test.sh
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ echo

# Check prerequisites
if ! command -v go &> /dev/null; then
echo "❌ Go is not installed. Please install Go 1.21 or later."
echo "❌ Go is not installed. Please install Go 1.24 or later."
echo " Visit: https://golang.org/dl/"
exit 1
fi
Expand Down Expand Up @@ -43,7 +43,7 @@ cd "$(dirname "$0")"
echo "=== Running Go SDK E2E Tests ==="
echo

go test -v ./...
go test -v ./... -race

echo
echo "✅ All tests passed!"
Loading