Skip to content

Commit 1f35139

Browse files
authored
feat: support avc format for screencapture on android devices
1 parent 11a07cc commit 1f35139

File tree

8 files changed

+109
-40
lines changed

8 files changed

+109
-40
lines changed

cli/screenshot.go

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313

1414
var (
1515
screencaptureScale float64
16+
screencaptureFPS int
1617
)
1718

1819
var screenshotCmd = &cobra.Command{
@@ -57,11 +58,11 @@ var screenshotCmd = &cobra.Command{
5758
var screencaptureCmd = &cobra.Command{
5859
Use: "screencapture",
5960
Short: "Stream screen capture from a connected device",
60-
Long: `Streams MJPEG screen capture from a specified device to stdout. Only supports MJPEG format.`,
61+
Long: `Streams screen capture from a specified device to stdout. Supports MJPEG and AVC formats (Android only for AVC).`,
6162
RunE: func(cmd *cobra.Command, args []string) error {
6263
// Validate format
63-
if screencaptureFormat != "mjpeg" {
64-
response := commands.NewErrorResponse(fmt.Errorf("format must be 'mjpeg' for screen capture"))
64+
if screencaptureFormat != "mjpeg" && screencaptureFormat != "avc" {
65+
response := commands.NewErrorResponse(fmt.Errorf("format must be 'mjpeg' or 'avc' for screen capture"))
6566
printJson(response)
6667
return fmt.Errorf("%s", response.Error)
6768
}
@@ -89,14 +90,20 @@ var screencaptureCmd = &cobra.Command{
8990
// set defaults if not provided
9091
scale := screencaptureScale
9192
if scale == 0.0 {
92-
scale = devices.DefaultMJPEGScale
93+
scale = devices.DefaultScale
94+
}
95+
96+
fps := screencaptureFPS
97+
if fps == 0 {
98+
fps = devices.DefaultFramerate
9399
}
94100

95101
// Start screen capture and stream to stdout
96102
err = targetDevice.StartScreenCapture(devices.ScreenCaptureConfig{
97103
Format: screencaptureFormat,
98-
Quality: devices.DefaultMJPEGQuality,
104+
Quality: devices.DefaultQuality,
99105
Scale: scale,
106+
FPS: fps,
100107
OnProgress: func(message string) {
101108
utils.Verbose(message)
102109
},
@@ -134,4 +141,5 @@ func init() {
134141
screencaptureCmd.Flags().StringVar(&deviceId, "device", "", "ID of the device to capture from")
135142
screencaptureCmd.Flags().StringVarP(&screencaptureFormat, "format", "f", "mjpeg", "Output format for screen capture")
136143
screencaptureCmd.Flags().Float64Var(&screencaptureScale, "scale", 0, "Scale factor for screen capture (0 for default)")
144+
screencaptureCmd.Flags().IntVar(&screencaptureFPS, "fps", 0, "Frames per second for screen capture (0 for default)")
137145
}

devices/android.go

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -784,12 +784,12 @@ func (d *AndroidDevice) GetAppPath(packageName string) (string, error) {
784784
}
785785

786786
func (d *AndroidDevice) StartScreenCapture(config ScreenCaptureConfig) error {
787-
if config.Format != "mjpeg" {
788-
return fmt.Errorf("unsupported format: %s, only 'mjpeg' is supported", config.Format)
787+
if config.Format != "mjpeg" && config.Format != "avc" {
788+
return fmt.Errorf("unsupported format: %s, only 'mjpeg' and 'avc' are supported", config.Format)
789789
}
790790

791791
if config.OnProgress != nil {
792-
config.OnProgress("Installing DeviceKit")
792+
config.OnProgress("Installing Agent")
793793
}
794794

795795
utils.Verbose("Ensuring DeviceKit is installed...")
@@ -803,12 +803,19 @@ func (d *AndroidDevice) StartScreenCapture(config ScreenCaptureConfig) error {
803803
return fmt.Errorf("failed to get app path: %v", err)
804804
}
805805

806+
var serverClass string
807+
if config.Format == "mjpeg" {
808+
serverClass = "com.mobilenext.devicekit.MjpegServer"
809+
} else {
810+
serverClass = "com.mobilenext.devicekit.AvcServer"
811+
}
812+
806813
if config.OnProgress != nil {
807-
config.OnProgress("Starting MJPEG server")
814+
config.OnProgress("Starting Agent")
808815
}
809816

810-
utils.Verbose("Starting MJPEG server with app path: %s", appPath)
811-
cmdArgs := append([]string{"-s", d.getAdbIdentifier()}, "exec-out", fmt.Sprintf("CLASSPATH=%s", appPath), "app_process", "/system/bin", "com.mobilenext.devicekit.MjpegServer", "--quality", fmt.Sprintf("%d", config.Quality), "--scale", fmt.Sprintf("%.2f", config.Scale))
817+
utils.Verbose("Starting %s with app path: %s", serverClass, appPath)
818+
cmdArgs := append([]string{"-s", d.getAdbIdentifier()}, "exec-out", fmt.Sprintf("CLASSPATH=%s", appPath), "app_process", "/system/bin", serverClass, "--quality", fmt.Sprintf("%d", config.Quality), "--scale", fmt.Sprintf("%.2f", config.Scale), "--fps", fmt.Sprintf("%d", config.FPS))
812819
utils.Verbose("Running command: %s %s", getAdbPath(), strings.Join(cmdArgs, " "))
813820
cmd := exec.Command(getAdbPath(), cmdArgs...)
814821

@@ -818,7 +825,7 @@ func (d *AndroidDevice) StartScreenCapture(config ScreenCaptureConfig) error {
818825
}
819826

820827
if err := cmd.Start(); err != nil {
821-
return fmt.Errorf("failed to start MJPEG server: %v", err)
828+
return fmt.Errorf("failed to start %s: %v", serverClass, err)
822829
}
823830

824831
// Read bytes from the command output and send to callback

devices/common.go

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,20 @@ import (
1010
)
1111

1212
const (
13-
// Default MJPEG streaming quality (1-100)
14-
DefaultMJPEGQuality = 80
15-
// Default MJPEG streaming scale (0.1-1.0)
16-
DefaultMJPEGScale = 1.0
17-
// Default MJPEG streaming framerate (frames per second)
18-
DefaultMJPEGFramerate = 30
13+
// default streaming quality (1-100)
14+
DefaultQuality = 80
15+
// default streaming scale (0.1-1.0)
16+
DefaultScale = 1.0
17+
// default streaming framerate (frames per second)
18+
DefaultFramerate = 30
1919
)
2020

2121
// ScreenCaptureConfig contains configuration for screen capture operations
2222
type ScreenCaptureConfig struct {
2323
Format string
2424
Quality int
2525
Scale float64
26+
FPS int
2627
OnProgress func(message string) // optional progress callback
2728
OnData func([]byte) bool // data callback - return false to stop
2829
}
@@ -92,13 +93,13 @@ func GetAllControllableDevices(includeOffline bool) ([]ControllableDevice, error
9293
offlineAndroidCount := 0
9394
offlineAndroidDuration := int64(0)
9495
if includeOffline {
95-
startOfflineAndroid := time.Now()
9696
// build map of online device IDs for quick lookup
9797
onlineDeviceIDs := make(map[string]bool)
9898
for _, device := range androidDevices {
9999
onlineDeviceIDs[device.ID()] = true
100100
}
101101

102+
startOfflineAndroid := time.Now()
102103
offlineEmulators, err := getOfflineAndroidEmulators(onlineDeviceIDs)
103104
offlineAndroidDuration = time.Since(startOfflineAndroid).Milliseconds()
104105
if err != nil {
@@ -183,6 +184,7 @@ type FullDeviceInfo struct {
183184

184185
// GetDeviceInfoList returns a list of DeviceInfo for all connected devices
185186
func GetDeviceInfoList(opts DeviceListOptions) ([]DeviceInfo, error) {
187+
startTime := time.Now()
186188
devices, err := GetAllControllableDevices(opts.IncludeOffline)
187189
if err != nil {
188190
return nil, fmt.Errorf("error getting devices: %w", err)
@@ -216,6 +218,7 @@ func GetDeviceInfoList(opts DeviceListOptions) ([]DeviceInfo, error) {
216218
State: state,
217219
})
218220
}
221+
utils.Verbose("GetDeviceInfoList took %s", time.Since(startTime))
219222

220223
return deviceInfoList, nil
221224
}

devices/ios.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,9 @@ func (d *IOSDevice) StartAgent(config StartAgentConfig) error {
277277
}
278278

279279
if webdriverBundleId == "" {
280+
if config.OnProgress != nil {
281+
config.OnProgress("Installing WebDriverAgent")
282+
}
280283
return fmt.Errorf("WebDriverAgent is not installed")
281284
}
282285

@@ -637,7 +640,11 @@ func (d IOSDevice) Info() (*FullDeviceInfo, error) {
637640

638641
func (d IOSDevice) StartScreenCapture(config ScreenCaptureConfig) error {
639642
// configure mjpeg framerate
640-
err := d.wdaClient.SetMjpegFramerate(DefaultMJPEGFramerate)
643+
fps := config.FPS
644+
if fps == 0 {
645+
fps = DefaultFramerate
646+
}
647+
err := d.wdaClient.SetMjpegFramerate(fps)
641648
if err != nil {
642649
return err
643650
}

devices/simulator.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,10 @@ func (s SimulatorDevice) InstallWebDriverAgent(onProgress func(string)) error {
353353
defer func() { _ = os.Remove(file) }()
354354
}
355355

356+
if onProgress != nil {
357+
onProgress("Installing WebDriverAgent")
358+
}
359+
356360
dir, err := utils.Unzip(file)
357361
if err != nil {
358362
return fmt.Errorf("failed to unzip WebDriverAgent: %w", err)
@@ -697,7 +701,11 @@ func (s *SimulatorDevice) StartScreenCapture(config ScreenCaptureConfig) error {
697701
}
698702

699703
// configure mjpeg framerate
700-
err = s.wdaClient.SetMjpegFramerate(DefaultMJPEGFramerate)
704+
fps := config.FPS
705+
if fps == 0 {
706+
fps = DefaultFramerate
707+
}
708+
err = s.wdaClient.SetMjpegFramerate(fps)
701709
if err != nil {
702710
return err
703711
}

devices/wda/requests.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,9 @@ func (c *WdaClient) isSessionStillValid(sessionId string) bool {
185185

186186
// GetOrCreateSession returns cached session or creates a new one
187187
func (c *WdaClient) GetOrCreateSession() (string, error) {
188+
c.mu.Lock()
189+
defer c.mu.Unlock()
190+
188191
// if we have a cached session, validate it first
189192
if c.sessionId != "" {
190193
if c.isSessionStillValid(c.sessionId) {
@@ -212,9 +215,11 @@ func (c *WdaClient) DeleteSession(sessionId string) error {
212215
return fmt.Errorf("failed to delete session %s: %w", sessionId, err)
213216
}
214217

218+
c.mu.Lock()
215219
if c.sessionId == sessionId {
216220
c.sessionId = ""
217221
}
222+
c.mu.Unlock()
218223

219224
return nil
220225
}

devices/wda/types.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@ package wda
33
import (
44
"net/http"
55
"strings"
6+
"sync"
67
"time"
78
)
89

910
type WdaClient struct {
1011
baseURL string
1112
httpClient *http.Client
1213
sessionId string
14+
mu sync.Mutex
1315
}
1416

1517
func NewWdaClient(hostPort string) *WdaClient {

server/server.go

Lines changed: 48 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,17 @@ func sendBanner(w http.ResponseWriter, r *http.Request) {
678678
_ = json.NewEncoder(w).Encode(okResponse)
679679
}
680680

681+
// newJsonRpcNotification creates a JSON-RPC notification message
682+
func newJsonRpcNotification(message string) map[string]interface{} {
683+
return map[string]interface{}{
684+
"jsonrpc": "2.0",
685+
"method": "notification/message",
686+
"params": map[string]string{
687+
"message": message,
688+
},
689+
}
690+
}
691+
681692
func handleScreenCapture(w http.ResponseWriter, params json.RawMessage) error {
682693

683694
_ = http.NewResponseController(w).SetWriteDeadline(time.Now().Add(10 * time.Minute))
@@ -693,41 +704,59 @@ func handleScreenCapture(w http.ResponseWriter, params json.RawMessage) error {
693704
return fmt.Errorf("error finding device: %w", err)
694705
}
695706

696-
if screenCaptureParams.Format == "" || screenCaptureParams.Format != "mjpeg" {
697-
return fmt.Errorf("format must be 'mjpeg' for screen capture")
707+
// Set default format if not provided
708+
if screenCaptureParams.Format == "" {
709+
screenCaptureParams.Format = "mjpeg"
710+
}
711+
712+
// Validate format
713+
if screenCaptureParams.Format != "mjpeg" && screenCaptureParams.Format != "avc" {
714+
return fmt.Errorf("format must be 'mjpeg' or 'avc' for screen capture")
715+
}
716+
717+
// AVC format is only supported on Android
718+
if screenCaptureParams.Format == "avc" && targetDevice.Platform() != "android" {
719+
return fmt.Errorf("avc format is only supported on Android devices")
698720
}
699721

700722
// Set defaults if not provided
701723
quality := screenCaptureParams.Quality
702724
if quality == 0 {
703-
quality = devices.DefaultMJPEGQuality
725+
quality = devices.DefaultQuality
704726
}
705727

706728
scale := screenCaptureParams.Scale
707729
if scale == 0.0 {
708-
scale = devices.DefaultMJPEGScale
730+
scale = devices.DefaultScale
709731
}
710732

711-
// Set headers for streaming response
712-
w.Header().Set("Content-Type", "multipart/x-mixed-replace; boundary=BoundaryString")
733+
// Set headers for streaming response based on format
734+
if screenCaptureParams.Format == "mjpeg" {
735+
w.Header().Set("Content-Type", "multipart/x-mixed-replace; boundary=BoundaryString")
736+
} else {
737+
// avc format
738+
w.Header().Set("Content-Type", "video/h264")
739+
}
713740
w.Header().Set("Cache-Control", "no-cache")
714741
w.Header().Set("Connection", "keep-alive")
715742
w.Header().Set("Transfer-Encoding", "chunked")
716743

717744
// progress callback sends JSON-RPC notifications through the MJPEG stream
718-
progressCallback := func(message string) {
719-
notification := map[string]interface{}{
720-
"jsonrpc": "2.0",
721-
"method": "notification/message",
722-
"params": map[string]string{
723-
"message": message,
724-
},
725-
}
726-
statusJSON, _ := json.Marshal(notification)
727-
mimeMessage := fmt.Sprintf("--BoundaryString\r\nContent-Type: application/json\r\nContent-Length: %d\r\n\r\n%s\r\n", len(statusJSON), statusJSON)
728-
_, _ = w.Write([]byte(mimeMessage))
729-
if flusher, ok := w.(http.Flusher); ok {
730-
flusher.Flush()
745+
// only used for MJPEG format, not for AVC
746+
var progressCallback func(string)
747+
if screenCaptureParams.Format == "mjpeg" {
748+
progressCallback = func(message string) {
749+
notification := newJsonRpcNotification(message)
750+
statusJSON, err := json.Marshal(notification)
751+
if err != nil {
752+
log.Printf("Failed to marshal progress message: %v", err)
753+
return
754+
}
755+
mimeMessage := fmt.Sprintf("--BoundaryString\r\nContent-Type: application/json\r\nContent-Length: %d\r\n\r\n%s\r\n", len(statusJSON), statusJSON)
756+
_, _ = w.Write([]byte(mimeMessage))
757+
if flusher, ok := w.(http.Flusher); ok {
758+
flusher.Flush()
759+
}
731760
}
732761
}
733762

0 commit comments

Comments
 (0)