diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
index ca511c4f388a..d8079daf1bde 100644
--- a/net/vmw_vsock/af_vsock.c
+++ b/net/vmw_vsock/af_vsock.c
@@ -207,7 +207,7 @@ static struct sock *__vsock_find_bound_socket(struct sockaddr_vm *addr)
 	struct vsock_sock *vsk;
 
 	list_for_each_entry(vsk, vsock_bound_sockets(addr), bound_table)
-		if (vsock_addr_equals_addr_any(addr, &vsk->local_addr))
+		if (addr->svm_port == vsk->local_addr.svm_port)
 			return sk_vsock(vsk);
 
 	return NULL;
@@ -220,8 +220,8 @@ static struct sock *__vsock_find_connected_socket(struct sockaddr_vm *src,
 
 	list_for_each_entry(vsk, vsock_connected_sockets(src, dst),
 			    connected_table) {
-		if (vsock_addr_equals_addr(src, &vsk->remote_addr)
-		    && vsock_addr_equals_addr(dst, &vsk->local_addr)) {
+		if (vsock_addr_equals_addr(src, &vsk->remote_addr) &&
+		    dst->svm_port == vsk->local_addr.svm_port) {
 			return sk_vsock(vsk);
 		}
 	}
diff --git a/net/vmw_vsock/vmci_transport.c b/net/vmw_vsock/vmci_transport.c
index a70ace83a153..1f6508e249ae 100644
--- a/net/vmw_vsock/vmci_transport.c
+++ b/net/vmw_vsock/vmci_transport.c
@@ -464,19 +464,16 @@ static struct sock *vmci_transport_get_pending(
 	struct vsock_sock *vlistener;
 	struct vsock_sock *vpending;
 	struct sock *pending;
+	struct sockaddr_vm src;
+
+	vsock_addr_init(&src, pkt->dg.src.context, pkt->src_port);
 
 	vlistener = vsock_sk(listener);
 
 	list_for_each_entry(vpending, &vlistener->pending_links,
 			    pending_links) {
-		struct sockaddr_vm src;
-		struct sockaddr_vm dst;
-
-		vsock_addr_init(&src, pkt->dg.src.context, pkt->src_port);
-		vsock_addr_init(&dst, pkt->dg.dst.context, pkt->dst_port);
-
 		if (vsock_addr_equals_addr(&src, &vpending->remote_addr) &&
-		    vsock_addr_equals_addr(&dst, &vpending->local_addr)) {
+		    pkt->dst_port == vpending->local_addr.svm_port) {
 			pending = sk_vsock(vpending);
 			sock_hold(pending);
 			goto found;
@@ -739,10 +736,15 @@ static int vmci_transport_recv_stream_cb(void *data, struct vmci_datagram *dg)
 	 */
 	bh_lock_sock(sk);
 
-	if (!sock_owned_by_user(sk) && sk->sk_state == SS_CONNECTED)
-		vmci_trans(vsk)->notify_ops->handle_notify_pkt(
-				sk, pkt, true, &dst, &src,
-				&bh_process_pkt);
+	if (!sock_owned_by_user(sk)) {
+		/* The local context ID may be out of date, update it. */
+		vsk->local_addr.svm_cid = dst.svm_cid;
+
+		if (sk->sk_state == SS_CONNECTED)
+			vmci_trans(vsk)->notify_ops->handle_notify_pkt(
+					sk, pkt, true, &dst, &src,
+					&bh_process_pkt);
+	}
 
 	bh_unlock_sock(sk);
 
@@ -902,6 +904,9 @@ static void vmci_transport_recv_pkt_work(struct work_struct *work)
 
 	lock_sock(sk);
 
+	/* The local context ID may be out of date. */
+	vsock_sk(sk)->local_addr.svm_cid = pkt->dg.dst.context;
+
 	switch (sk->sk_state) {
 	case SS_LISTEN:
 		vmci_transport_recv_listen(sk, pkt);
@@ -958,6 +963,10 @@ static int vmci_transport_recv_listen(struct sock *sk,
 	pending = vmci_transport_get_pending(sk, pkt);
 	if (pending) {
 		lock_sock(pending);
+
+		/* The local context ID may be out of date. */
+		vsock_sk(pending)->local_addr.svm_cid = pkt->dg.dst.context;
+
 		switch (pending->sk_state) {
 		case SS_CONNECTING:
 			err = vmci_transport_recv_connecting_server(sk,
diff --git a/net/vmw_vsock/vsock_addr.c b/net/vmw_vsock/vsock_addr.c
index b7df1aea7c59..ec2611b4ea0e 100644
--- a/net/vmw_vsock/vsock_addr.c
+++ b/net/vmw_vsock/vsock_addr.c
@@ -64,16 +64,6 @@ bool vsock_addr_equals_addr(const struct sockaddr_vm *addr,
 }
 EXPORT_SYMBOL_GPL(vsock_addr_equals_addr);
 
-bool vsock_addr_equals_addr_any(const struct sockaddr_vm *addr,
-				const struct sockaddr_vm *other)
-{
-	return (addr->svm_cid == VMADDR_CID_ANY ||
-		other->svm_cid == VMADDR_CID_ANY ||
-		addr->svm_cid == other->svm_cid) &&
-	       addr->svm_port == other->svm_port;
-}
-EXPORT_SYMBOL_GPL(vsock_addr_equals_addr_any);
-
 int vsock_addr_cast(const struct sockaddr *addr,
 		    size_t len, struct sockaddr_vm **out_addr)
 {
diff --git a/net/vmw_vsock/vsock_addr.h b/net/vmw_vsock/vsock_addr.h
index cdfbcefdf843..9ccd5316eac0 100644
--- a/net/vmw_vsock/vsock_addr.h
+++ b/net/vmw_vsock/vsock_addr.h
@@ -24,8 +24,6 @@ bool vsock_addr_bound(const struct sockaddr_vm *addr);
 void vsock_addr_unbind(struct sockaddr_vm *addr);
 bool vsock_addr_equals_addr(const struct sockaddr_vm *addr,
 			    const struct sockaddr_vm *other);
-bool vsock_addr_equals_addr_any(const struct sockaddr_vm *addr,
-				const struct sockaddr_vm *other);
 int vsock_addr_cast(const struct sockaddr *addr, size_t len,
 		    struct sockaddr_vm **out_addr);