From ddf96fa712f5192070e7df7746bcb31d79db712b Mon Sep 17 00:00:00 2001 From: Greg Kroah-Hartman Date: Mon, 6 Nov 2023 14:49:18 +0000 Subject: [PATCH] Revert "tcp: allow again tcp_disconnect() when threads are waiting" This reverts commit ec9bc89a018842006d63f6545c50768e79bd89f8 which is commit 419ce133ab928ab5efd7b50b2ef36ddfd4eadbd2 upstream. It breaks the android ABI and if this is needed in the future, can be brought back in an abi-safe way. Bug: 161946584 Change-Id: I591c4ae39181ebf38284aaeb927e890a08380e2b Signed-off-by: Greg Kroah-Hartman --- .../chelsio/inline_crypto/chtls/chtls_io.c | 36 ++++--------------- include/net/sock.h | 10 +++--- net/core/stream.c | 12 +++---- net/ipv4/af_inet.c | 10 ++---- net/ipv4/inet_connection_sock.c | 1 + net/ipv4/tcp.c | 16 ++++----- net/ipv4/tcp_bpf.c | 4 --- net/mptcp/protocol.c | 7 ++++ net/tls/tls_main.c | 10 ++---- net/tls/tls_sw.c | 19 ++++------ 10 files changed, 45 insertions(+), 80 deletions(-) diff --git a/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls_io.c b/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls_io.c index 5e45bef4fd34..a4256087ac82 100644 --- a/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls_io.c +++ b/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls_io.c @@ -911,7 +911,7 @@ static int csk_wait_memory(struct chtls_dev *cdev, struct sock *sk, long *timeo_p) { DEFINE_WAIT_FUNC(wait, woken_wake_function); - int ret, err = 0; + int err = 0; long current_timeo; long vm_wait = 0; bool noblock; @@ -942,13 +942,10 @@ static int csk_wait_memory(struct chtls_dev *cdev, set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); sk->sk_write_pending++; - ret = sk_wait_event(sk, ¤t_timeo, sk->sk_err || - (sk->sk_shutdown & SEND_SHUTDOWN) || - (csk_mem_free(cdev, sk) && !vm_wait), - &wait); + sk_wait_event(sk, ¤t_timeo, sk->sk_err || + (sk->sk_shutdown & SEND_SHUTDOWN) || + (csk_mem_free(cdev, sk) && !vm_wait), &wait); sk->sk_write_pending--; - if (ret < 0) - goto do_error; if (vm_wait) { vm_wait -= current_timeo; @@ -1441,7 +1438,6 @@ static int chtls_pt_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int copied = 0; int target; long timeo; - int ret; buffers_freed = 0; @@ -1517,11 +1513,7 @@ static int chtls_pt_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, if (copied >= target) break; chtls_cleanup_rbuf(sk, copied); - ret = sk_wait_data(sk, &timeo, NULL); - if (ret < 0) { - copied = copied ? : ret; - goto unlock; - } + sk_wait_data(sk, &timeo, NULL); continue; found_ok_skb: if (!skb->len) { @@ -1616,8 +1608,6 @@ skip_copy: if (buffers_freed) chtls_cleanup_rbuf(sk, copied); - -unlock: release_sock(sk); return copied; } @@ -1634,7 +1624,6 @@ static int peekmsg(struct sock *sk, struct msghdr *msg, int copied = 0; size_t avail; /* amount of available data in current skb */ long timeo; - int ret; lock_sock(sk); timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); @@ -1686,12 +1675,7 @@ static int peekmsg(struct sock *sk, struct msghdr *msg, release_sock(sk); lock_sock(sk); } else { - ret = sk_wait_data(sk, &timeo, NULL); - if (ret < 0) { - /* here 'copied' is 0 due to previous checks */ - copied = ret; - break; - } + sk_wait_data(sk, &timeo, NULL); } if (unlikely(peek_seq != tp->copied_seq)) { @@ -1762,7 +1746,6 @@ int chtls_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int copied = 0; long timeo; int target; /* Read at least this many bytes */ - int ret; buffers_freed = 0; @@ -1854,11 +1837,7 @@ int chtls_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, if (copied >= target) break; chtls_cleanup_rbuf(sk, copied); - ret = sk_wait_data(sk, &timeo, NULL); - if (ret < 0) { - copied = copied ? : ret; - goto unlock; - } + sk_wait_data(sk, &timeo, NULL); continue; found_ok_skb: @@ -1927,7 +1906,6 @@ skip_copy: if (buffers_freed) chtls_cleanup_rbuf(sk, copied); -unlock: release_sock(sk); return copied; } diff --git a/include/net/sock.h b/include/net/sock.h index c6328af832ed..20ff5bf6c0f2 100644 --- a/include/net/sock.h +++ b/include/net/sock.h @@ -335,7 +335,7 @@ struct sk_filter; * @sk_cgrp_data: cgroup data for this cgroup * @sk_memcg: this socket's memory cgroup association * @sk_write_pending: a write to stream socket waits to start - * @sk_disconnects: number of disconnect operations performed on this sock + * @sk_wait_pending: number of threads blocked on this socket * @sk_state_change: callback to indicate change in the state of the sock * @sk_data_ready: callback to indicate there is data to be processed * @sk_write_space: callback to indicate there is bf sending space available @@ -428,7 +428,7 @@ struct sock { unsigned int sk_napi_id; #endif int sk_rcvbuf; - int sk_disconnects; + int sk_wait_pending; struct sk_filter __rcu *sk_filter; union { @@ -1197,7 +1197,8 @@ static inline void sock_rps_reset_rxhash(struct sock *sk) } #define sk_wait_event(__sk, __timeo, __condition, __wait) \ - ({ int __rc, __dis = __sk->sk_disconnects; \ + ({ int __rc; \ + __sk->sk_wait_pending++; \ release_sock(__sk); \ __rc = __condition; \ if (!__rc) { \ @@ -1207,7 +1208,8 @@ static inline void sock_rps_reset_rxhash(struct sock *sk) } \ sched_annotate_sleep(); \ lock_sock(__sk); \ - __rc = __dis == __sk->sk_disconnects ? __condition : -EPIPE; \ + __sk->sk_wait_pending--; \ + __rc = __condition; \ __rc; \ }) diff --git a/net/core/stream.c b/net/core/stream.c index 051aa71a8ad0..5b05b889d31a 100644 --- a/net/core/stream.c +++ b/net/core/stream.c @@ -117,7 +117,7 @@ EXPORT_SYMBOL(sk_stream_wait_close); */ int sk_stream_wait_memory(struct sock *sk, long *timeo_p) { - int ret, err = 0; + int err = 0; long vm_wait = 0; long current_timeo = *timeo_p; DEFINE_WAIT_FUNC(wait, woken_wake_function); @@ -142,13 +142,11 @@ int sk_stream_wait_memory(struct sock *sk, long *timeo_p) set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); sk->sk_write_pending++; - ret = sk_wait_event(sk, ¤t_timeo, READ_ONCE(sk->sk_err) || - (READ_ONCE(sk->sk_shutdown) & SEND_SHUTDOWN) || - (sk_stream_memory_free(sk) && !vm_wait), - &wait); + sk_wait_event(sk, ¤t_timeo, READ_ONCE(sk->sk_err) || + (READ_ONCE(sk->sk_shutdown) & SEND_SHUTDOWN) || + (sk_stream_memory_free(sk) && + !vm_wait), &wait); sk->sk_write_pending--; - if (ret < 0) - goto do_error; if (vm_wait) { vm_wait -= current_timeo; diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c index 2aedb9d46687..ebb737ac9e89 100644 --- a/net/ipv4/af_inet.c +++ b/net/ipv4/af_inet.c @@ -589,6 +589,7 @@ static long inet_wait_for_connect(struct sock *sk, long timeo, int writebias) add_wait_queue(sk_sleep(sk), &wait); sk->sk_write_pending += writebias; + sk->sk_wait_pending++; /* Basic assumption: if someone sets sk->sk_err, he _must_ * change state of the socket from TCP_SYN_*. @@ -604,6 +605,7 @@ static long inet_wait_for_connect(struct sock *sk, long timeo, int writebias) } remove_wait_queue(sk_sleep(sk), &wait); sk->sk_write_pending -= writebias; + sk->sk_wait_pending--; return timeo; } @@ -632,7 +634,6 @@ int __inet_stream_connect(struct socket *sock, struct sockaddr *uaddr, return -EINVAL; if (uaddr->sa_family == AF_UNSPEC) { - sk->sk_disconnects++; err = sk->sk_prot->disconnect(sk, flags); sock->state = err ? SS_DISCONNECTING : SS_UNCONNECTED; goto out; @@ -687,7 +688,6 @@ int __inet_stream_connect(struct socket *sock, struct sockaddr *uaddr, int writebias = (sk->sk_protocol == IPPROTO_TCP) && tcp_sk(sk)->fastopen_req && tcp_sk(sk)->fastopen_req->data ? 1 : 0; - int dis = sk->sk_disconnects; /* Error code is set above */ if (!timeo || !inet_wait_for_connect(sk, timeo, writebias)) @@ -696,11 +696,6 @@ int __inet_stream_connect(struct socket *sock, struct sockaddr *uaddr, err = sock_intr_errno(timeo); if (signal_pending(current)) goto out; - - if (dis != sk->sk_disconnects) { - err = -EPIPE; - goto out; - } } /* Connection was closed by RST, timeout, ICMP error @@ -722,7 +717,6 @@ out: sock_error: err = sock_error(sk) ? : -ECONNABORTED; sock->state = SS_UNCONNECTED; - sk->sk_disconnects++; if (sk->sk_prot->disconnect(sk, flags)) sock->state = SS_DISCONNECTING; goto out; diff --git a/net/ipv4/inet_connection_sock.c b/net/ipv4/inet_connection_sock.c index 80ce0112e24b..62a3b103f258 100644 --- a/net/ipv4/inet_connection_sock.c +++ b/net/ipv4/inet_connection_sock.c @@ -1143,6 +1143,7 @@ struct sock *inet_csk_clone_lock(const struct sock *sk, if (newsk) { struct inet_connection_sock *newicsk = inet_csk(newsk); + newsk->sk_wait_pending = 0; inet_sk_set_state(newsk, TCP_SYN_RECV); newicsk->icsk_bind_hash = NULL; newicsk->icsk_bind2_hash = NULL; diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c index 775ace7bb7ce..96d248f07c81 100644 --- a/net/ipv4/tcp.c +++ b/net/ipv4/tcp.c @@ -827,9 +827,7 @@ ssize_t tcp_splice_read(struct socket *sock, loff_t *ppos, */ if (!skb_queue_empty(&sk->sk_receive_queue)) break; - ret = sk_wait_data(sk, &timeo, NULL); - if (ret < 0) - break; + sk_wait_data(sk, &timeo, NULL); if (signal_pending(current)) { ret = sock_intr_errno(timeo); break; @@ -2551,11 +2549,7 @@ static int tcp_recvmsg_locked(struct sock *sk, struct msghdr *msg, size_t len, __sk_flush_backlog(sk); } else { tcp_cleanup_rbuf(sk, copied); - err = sk_wait_data(sk, &timeo, last); - if (err < 0) { - err = copied ? : err; - goto out; - } + sk_wait_data(sk, &timeo, last); } if ((flags & MSG_PEEK) && @@ -3079,6 +3073,12 @@ int tcp_disconnect(struct sock *sk, int flags) int old_state = sk->sk_state; u32 seq; + /* Deny disconnect if other threads are blocked in sk_wait_event() + * or inet_wait_for_connect(). + */ + if (sk->sk_wait_pending) + return -EBUSY; + if (old_state != TCP_CLOSE) tcp_set_state(sk, TCP_CLOSE); diff --git a/net/ipv4/tcp_bpf.c b/net/ipv4/tcp_bpf.c index cb4549db8bcf..f53380fd89bc 100644 --- a/net/ipv4/tcp_bpf.c +++ b/net/ipv4/tcp_bpf.c @@ -302,8 +302,6 @@ msg_bytes_ready: } data = tcp_msg_wait_data(sk, psock, timeo); - if (data < 0) - return data; if (data && !sk_psock_queue_empty(psock)) goto msg_bytes_ready; copied = -EAGAIN; @@ -348,8 +346,6 @@ msg_bytes_ready: timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); data = tcp_msg_wait_data(sk, psock, timeo); - if (data < 0) - return data; if (data) { if (!sk_psock_queue_empty(psock)) goto msg_bytes_ready; diff --git a/net/mptcp/protocol.c b/net/mptcp/protocol.c index 0eb20274459c..cb108dfc6dfd 100644 --- a/net/mptcp/protocol.c +++ b/net/mptcp/protocol.c @@ -3117,6 +3117,12 @@ static int mptcp_disconnect(struct sock *sk, int flags) { struct mptcp_sock *msk = mptcp_sk(sk); + /* Deny disconnect if other threads are blocked in sk_wait_event() + * or inet_wait_for_connect(). + */ + if (sk->sk_wait_pending) + return -EBUSY; + /* We are on the fastopen error path. We can't call straight into the * subflows cleanup code due to lock nesting (we are already under * msk->firstsocket lock). @@ -3184,6 +3190,7 @@ struct sock *mptcp_sk_clone_init(const struct sock *sk, inet_sk(nsk)->pinet6 = mptcp_inet6_sk(nsk); #endif + nsk->sk_wait_pending = 0; __mptcp_init_sock(nsk); msk = mptcp_sk(nsk); diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index 338a443fa47b..f2e7302a4d96 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -96,8 +96,8 @@ void update_sk_prot(struct sock *sk, struct tls_context *ctx) int wait_on_pending_writer(struct sock *sk, long *timeo) { + int rc = 0; DEFINE_WAIT_FUNC(wait, woken_wake_function); - int ret, rc = 0; add_wait_queue(sk_sleep(sk), &wait); while (1) { @@ -111,13 +111,9 @@ int wait_on_pending_writer(struct sock *sk, long *timeo) break; } - ret = sk_wait_event(sk, timeo, - !READ_ONCE(sk->sk_write_pending), &wait); - if (ret) { - if (ret < 0) - rc = ret; + if (sk_wait_event(sk, timeo, + !READ_ONCE(sk->sk_write_pending), &wait)) break; - } } remove_wait_queue(sk_sleep(sk), &wait); return rc; diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index 2af72d349192..c5c8fdadc05e 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -1296,7 +1296,6 @@ tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock, struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); DEFINE_WAIT_FUNC(wait, woken_wake_function); - int ret = 0; long timeo; timeo = sock_rcvtimeo(sk, nonblock); @@ -1308,9 +1307,6 @@ tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock, if (sk->sk_err) return sock_error(sk); - if (ret < 0) - return ret; - if (!skb_queue_empty(&sk->sk_receive_queue)) { tls_strp_check_rcv(&ctx->strp); if (tls_strp_msg_ready(ctx)) @@ -1329,10 +1325,10 @@ tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock, released = true; add_wait_queue(sk_sleep(sk), &wait); sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); - ret = sk_wait_event(sk, &timeo, - tls_strp_msg_ready(ctx) || - !sk_psock_queue_empty(psock), - &wait); + sk_wait_event(sk, &timeo, + tls_strp_msg_ready(ctx) || + !sk_psock_queue_empty(psock), + &wait); sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); remove_wait_queue(sk_sleep(sk), &wait); @@ -1859,7 +1855,6 @@ static int tls_rx_reader_acquire(struct sock *sk, struct tls_sw_context_rx *ctx, bool nonblock) { long timeo; - int ret; timeo = sock_rcvtimeo(sk, nonblock); @@ -1869,16 +1864,14 @@ static int tls_rx_reader_acquire(struct sock *sk, struct tls_sw_context_rx *ctx, ctx->reader_contended = 1; add_wait_queue(&ctx->wq, &wait); - ret = sk_wait_event(sk, &timeo, - !READ_ONCE(ctx->reader_present), &wait); + sk_wait_event(sk, &timeo, + !READ_ONCE(ctx->reader_present), &wait); remove_wait_queue(&ctx->wq, &wait); if (timeo <= 0) return -EAGAIN; if (signal_pending(current)) return sock_intr_errno(timeo); - if (ret < 0) - return ret; } WRITE_ONCE(ctx->reader_present, 1);