adb: Check sender's socket id when receiving packets.

handle_packet() in adb.c didn't check that when an A_WRTE packet is
received, the sender's local-id matches the socket's peer id.

This meant that a compromised adbd server could send packets to
the host adb server, spoofing the identity of another connected
device if it could "guess" the right host socket id.

This patch gets rid of the issue by enforcing even more checks
to ensure that all packets comply with the description in
protocol.txt.

+ Fix a bug where closing a local socket associated with a
  remote one would always send an A_CLSE(0, remote-id, "")
  message, though protocol.txt says that should only happen
  for failed opens.

  The issue was that local_socket_close() called
  remote_socket_close() after clearing the remote socket's
  'peer' field.

  The fix introduces a new asocket optional callback,
  named 'shutdown' that is called before that, and is
  used to send the A_CLSE() message with the right ID
  in remote_socket_shutdown().

  Also add some code in handle_packet() to detect
  invalid close commands.

Change-Id: I9098bc8c6e81f8809334b060e5dca4fc92e6fbc9
diff --git a/adb/adb.c b/adb/adb.c
index 72b7484..41270f9 100644
--- a/adb/adb.c
+++ b/adb/adb.c
@@ -562,7 +562,7 @@
         break;
 
     case A_OPEN: /* OPEN(local-id, 0, "destination") */
-        if (t->online) {
+        if (t->online && p->msg.arg0 != 0 && p->msg.arg1 == 0) {
             char *name = (char*) p->data;
             name[p->msg.data_length > 0 ? p->msg.data_length - 1 : 0] = 0;
             s = create_local_service_socket(name);
@@ -578,28 +578,50 @@
         break;
 
     case A_OKAY: /* READY(local-id, remote-id, "") */
-        if (t->online) {
-            if((s = find_local_socket(p->msg.arg1))) {
+        if (t->online && p->msg.arg0 != 0 && p->msg.arg1 != 0) {
+            if((s = find_local_socket(p->msg.arg1, 0))) {
                 if(s->peer == 0) {
+                    /* On first READY message, create the connection. */
                     s->peer = create_remote_socket(p->msg.arg0, t);
                     s->peer->peer = s;
+                    s->ready(s);
+                } else if (s->peer->id == p->msg.arg0) {
+                    /* Other READY messages must use the same local-id */
+                    s->ready(s);
+                } else {
+                    D("Invalid A_OKAY(%d,%d), expected A_OKAY(%d,%d) on transport %s\n",
+                      p->msg.arg0, p->msg.arg1, s->peer->id, p->msg.arg1, t->serial);
                 }
-                s->ready(s);
             }
         }
         break;
 
-    case A_CLSE: /* CLOSE(local-id, remote-id, "") */
-        if (t->online) {
-            if((s = find_local_socket(p->msg.arg1))) {
-                s->close(s);
+    case A_CLSE: /* CLOSE(local-id, remote-id, "") or CLOSE(0, remote-id, "") */
+        if (t->online && p->msg.arg1 != 0) {
+            if((s = find_local_socket(p->msg.arg1, p->msg.arg0))) {
+                /* According to protocol.txt, p->msg.arg0 might be 0 to indicate
+                 * a failed OPEN only. However, due to a bug in previous ADB
+                 * versions, CLOSE(0, remote-id, "") was also used for normal
+                 * CLOSE() operations.
+                 *
+                 * This is bad because it means a compromised adbd could
+                 * send packets to close connections between the host and
+                 * other devices. To avoid this, only allow this if the local
+                 * socket has a peer on the same transport.
+                 */
+                if (p->msg.arg0 == 0 && s->peer && s->peer->transport != t) {
+                    D("Invalid A_CLSE(0, %u) from transport %s, expected transport %s\n",
+                      p->msg.arg1, t->serial, s->peer->transport->serial);
+                } else {
+                    s->close(s);
+                }
             }
         }
         break;
 
-    case A_WRTE:
-        if (t->online) {
-            if((s = find_local_socket(p->msg.arg1))) {
+    case A_WRTE: /* WRITE(local-id, remote-id, <data>) */
+        if (t->online && p->msg.arg0 != 0 && p->msg.arg1 != 0) {
+            if((s = find_local_socket(p->msg.arg1, p->msg.arg0))) {
                 unsigned rid = p->msg.arg0;
                 p->len = p->msg.data_length;
 
diff --git a/adb/adb.h b/adb/adb.h
index 622ca70..c85b02a 100644
--- a/adb/adb.h
+++ b/adb/adb.h
@@ -122,6 +122,12 @@
         */
     void (*ready)(asocket *s);
 
+        /* shutdown is called by the peer before it goes away.
+        ** the socket should not do any further calls on its peer.
+        ** Always followed by a call to close. Optional, i.e. can be NULL.
+        */
+    void (*shutdown)(asocket *s);
+
         /* close is called by the peer when it has gone away.
         ** we are not allowed to make any further calls on the
         ** peer once our close method is called.
@@ -233,7 +239,7 @@
 
 void print_packet(const char *label, apacket *p);
 
-asocket *find_local_socket(unsigned id);
+asocket *find_local_socket(unsigned local_id, unsigned remote_id);
 void install_local_socket(asocket *s);
 void remove_socket(asocket *s);
 void close_all_sockets(atransport *t);
diff --git a/adb/sockets.c b/adb/sockets.c
index f17608b..7f54ad9 100644
--- a/adb/sockets.c
+++ b/adb/sockets.c
@@ -59,17 +59,22 @@
     .prev = &local_socket_closing_list,
 };
 
-asocket *find_local_socket(unsigned id)
+// Parse the global list of sockets to find one with id |local_id|.
+// If |peer_id| is not 0, also check that it is connected to a peer
+// with id |peer_id|. Returns an asocket handle on success, NULL on failure.
+asocket *find_local_socket(unsigned local_id, unsigned peer_id)
 {
     asocket *s;
     asocket *result = NULL;
 
     adb_mutex_lock(&socket_list_lock);
     for (s = local_socket_list.next; s != &local_socket_list; s = s->next) {
-        if (s->id == id) {
+        if (s->id != local_id)
+            continue;
+        if (peer_id == 0 || (s->peer && s->peer->id == peer_id)) {
             result = s;
-            break;
         }
+        break;
     }
     adb_mutex_unlock(&socket_list_lock);
 
@@ -91,6 +96,11 @@
     adb_mutex_lock(&socket_list_lock);
 
     s->id = local_socket_next_id++;
+
+    // Socket ids should never be 0.
+    if (local_socket_next_id == 0)
+      local_socket_next_id = 1;
+
     insert_local_socket(s, &local_socket_list);
 
     adb_mutex_unlock(&socket_list_lock);
@@ -230,6 +240,12 @@
     if(s->peer) {
         D("LS(%d): closing peer. peer->id=%d peer->fd=%d\n",
           s->id, s->peer->id, s->peer->fd);
+        /* Note: it's important to call shutdown before disconnecting from
+         * the peer, this ensures that remote sockets can still get the id
+         * of the local socket they're connected to, to send a CLOSE()
+         * protocol event. */
+        if (s->peer->shutdown)
+          s->peer->shutdown(s->peer);
         s->peer->peer = 0;
         // tweak to avoid deadlock
         if (s->peer->close == local_socket_close) {
@@ -397,6 +413,7 @@
     s->fd = fd;
     s->enqueue = local_socket_enqueue;
     s->ready = local_socket_ready;
+    s->shutdown = NULL;
     s->close = local_socket_close;
     install_local_socket(s);
 
@@ -485,21 +502,29 @@
     send_packet(p, s->transport);
 }
 
-static void remote_socket_close(asocket *s)
+static void remote_socket_shutdown(asocket *s)
 {
-    D("entered remote_socket_close RS(%d) CLOSE fd=%d peer->fd=%d\n",
+    D("entered remote_socket_shutdown RS(%d) CLOSE fd=%d peer->fd=%d\n",
       s->id, s->fd, s->peer?s->peer->fd:-1);
     apacket *p = get_apacket();
     p->msg.command = A_CLSE;
     if(s->peer) {
         p->msg.arg0 = s->peer->id;
+    }
+    p->msg.arg1 = s->id;
+    send_packet(p, s->transport);
+}
+
+static void remote_socket_close(asocket *s)
+{
+    if (s->peer) {
         s->peer->peer = 0;
         D("RS(%d) peer->close()ing peer->id=%d peer->fd=%d\n",
           s->id, s->peer->id, s->peer->fd);
         s->peer->close(s->peer);
     }
-    p->msg.arg1 = s->id;
-    send_packet(p, s->transport);
+    D("entered remote_socket_close RS(%d) CLOSE fd=%d peer->fd=%d\n",
+      s->id, s->fd, s->peer?s->peer->fd:-1);
     D("RS(%d): closed\n", s->id);
     remove_transport_disconnect( s->transport, &((aremotesocket*)s)->disconnect );
     free(s);
@@ -519,15 +544,24 @@
     free(s);
 }
 
+/* Create an asocket to exchange packets with a remote service through transport
+  |t|, where |id| is the socket id of the corresponding service on the other
+   side of the transport (it is allocated by the remote side and _cannot_ be 0).
+   Returns a new non-NULL asocket handle. */
 asocket *create_remote_socket(unsigned id, atransport *t)
 {
-    asocket *s = calloc(1, sizeof(aremotesocket));
-    adisconnect*  dis = &((aremotesocket*)s)->disconnect;
+    asocket* s;
+    adisconnect* dis;
+
+    if (id == 0) fatal("invalid remote socket id (0)");
+    s = calloc(1, sizeof(aremotesocket));
+    dis = &((aremotesocket*)s)->disconnect;
 
     if (s == NULL) fatal("cannot allocate socket");
     s->id = id;
     s->enqueue = remote_socket_enqueue;
     s->ready = remote_socket_ready;
+    s->shutdown = remote_socket_shutdown;
     s->close = remote_socket_close;
     s->transport = t;
 
@@ -562,6 +596,7 @@
 static void local_socket_ready_notify(asocket *s)
 {
     s->ready = local_socket_ready;
+    s->shutdown = NULL;
     s->close = local_socket_close;
     adb_write(s->fd, "OKAY", 4);
     s->ready(s);
@@ -573,6 +608,7 @@
 static void local_socket_close_notify(asocket *s)
 {
     s->ready = local_socket_ready;
+    s->shutdown = NULL;
     s->close = local_socket_close;
     sendfailmsg(s->fd, "closed");
     s->close(s);
@@ -767,6 +803,7 @@
         adb_write(s->peer->fd, "OKAY", 4);
 
         s->peer->ready = local_socket_ready;
+        s->peer->shutdown = NULL;
         s->peer->close = local_socket_close;
         s->peer->peer = s2;
         s2->peer = s->peer;
@@ -806,6 +843,7 @@
         ** tear down
         */
     s->peer->ready = local_socket_ready_notify;
+    s->peer->shutdown = NULL;
     s->peer->close = local_socket_close_notify;
     s->peer->peer = 0;
         /* give him our transport and upref it */
@@ -851,6 +889,7 @@
     if (s == NULL) fatal("cannot allocate socket");
     s->enqueue = smart_socket_enqueue;
     s->ready = smart_socket_ready;
+    s->shutdown = NULL;
     s->close = smart_socket_close;
 
     D("SS(%d)\n", s->id);