Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/brpc/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1124,7 +1124,8 @@ int Server::StartInternal(const butil::EndPoint& endpoint,
_listen_addr = endpoint;
for (int port = port_range.min_port; port <= port_range.max_port; ++port) {
_listen_addr.port = port;
butil::fd_guard sockfd(tcp_listen(_listen_addr));
butil::fd_guard sockfd(tcp_listen(_listen_addr,
SetSocketBufferOptions));
if (sockfd < 0) {
if (port != port_range.max_port) { // not the last port, try next
continue;
Expand Down Expand Up @@ -1192,7 +1193,8 @@ int Server::StartInternal(const butil::EndPoint& endpoint,

butil::EndPoint internal_point = _listen_addr;
internal_point.port = _options.internal_port;
butil::fd_guard sockfd(tcp_listen(internal_point));
butil::fd_guard sockfd(tcp_listen(internal_point,
SetSocketBufferOptions));
if (sockfd < 0) {
LOG(ERROR) << "Fail to listen " << internal_point << " (internal)";
return -1;
Expand Down
36 changes: 20 additions & 16 deletions src/brpc/socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -624,22 +624,6 @@ void Socket::SetSocketOptions(int fd) {
PLOG(ERROR) << "Fail to set tos of fd=" << fd << " to " << _tos;
}

if (FLAGS_socket_send_buffer_size > 0) {
int buff_size = FLAGS_socket_send_buffer_size;
if (setsockopt(fd, SOL_SOCKET, SO_SNDBUF, &buff_size, sizeof(buff_size)) != 0) {
PLOG(ERROR) << "Fail to set sndbuf of fd=" << fd << " to "
<< buff_size;
}
}

if (FLAGS_socket_recv_buffer_size > 0) {
int buff_size = FLAGS_socket_recv_buffer_size;
if (setsockopt(fd, SOL_SOCKET, SO_RCVBUF, &buff_size, sizeof(buff_size)) != 0) {
PLOG(ERROR) << "Fail to set rcvbuf of fd=" << fd << " to "
<< buff_size;
}
}

#if defined(OS_LINUX)
if (_tcp_user_timeout_ms > 0) {
if (setsockopt(fd, IPPROTO_TCP, TCP_USER_TIMEOUT,
Expand Down Expand Up @@ -710,6 +694,24 @@ void Socket::SetSocketOptions(int fd) {
#endif
}

void SetSocketBufferOptions(int fd) {
if (FLAGS_socket_send_buffer_size > 0) {
int buff_size = FLAGS_socket_send_buffer_size;
if (setsockopt(fd, SOL_SOCKET, SO_SNDBUF, &buff_size, sizeof(buff_size)) != 0) {
PLOG(ERROR) << "Fail to set sndbuf of fd=" << fd << " to "
<< buff_size;
}
}

if (FLAGS_socket_recv_buffer_size > 0) {
int buff_size = FLAGS_socket_recv_buffer_size;
if (setsockopt(fd, SOL_SOCKET, SO_RCVBUF, &buff_size, sizeof(buff_size)) != 0) {
PLOG(ERROR) << "Fail to set rcvbuf of fd=" << fd << " to "
<< buff_size;
}
}
}

// SocketId = 32-bit version + 32-bit slot.
// version: from version part of _versioned_nref, must be an EVEN number.
// slot: designated by ResourcePool.
Expand Down Expand Up @@ -1271,6 +1273,8 @@ int Socket::Connect(const timespec* abstime,
CHECK_EQ(0, butil::make_close_on_exec(sockfd));
// We need to do async connect (to manage the timeout by ourselves).
CHECK_EQ(0, butil::make_non_blocking(sockfd));
// Socket buffer sizes need to be set before connect.
brpc::SetSocketBufferOptions(sockfd);
if (!_device_name.empty()) {
#ifdef SO_BINDTODEVICE
if (setsockopt(sockfd, SOL_SOCKET, SO_BINDTODEVICE,
Expand Down
3 changes: 3 additions & 0 deletions src/brpc/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ class EventDispatcher;
class Stream;
class Transport;

// Set SO_SNDBUF/SO_RCVBUF according to socket_*_buffer_size flags.
void SetSocketBufferOptions(int fd);

// A special closure for processing the about-to-recycle socket. Socket does
// not delete SocketUser, if you want, `delete this' at the end of
// BeforeRecycle().
Expand Down
10 changes: 9 additions & 1 deletion src/butil/endpoint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,7 @@ int tcp_connect(const EndPoint& server, int* self_port, int connect_timeout_ms)
return sockfd.release();
}

int tcp_listen(EndPoint point) {
int tcp_listen(EndPoint point, BeforeListenCallback before_listen) {
struct sockaddr_storage serv_addr;
socklen_t serv_addr_size = 0;
if (endpoint2sockaddr(point, &serv_addr, &serv_addr_size) != 0) {
Expand Down Expand Up @@ -602,6 +602,10 @@ int tcp_listen(EndPoint point) {
::unlink(((sockaddr_un*) &serv_addr)->sun_path);
}

if (before_listen) {
before_listen(sockfd);
}

if (::bind(sockfd, (struct sockaddr*)& serv_addr, serv_addr_size) != 0) {
return -1;
}
Expand All @@ -614,6 +618,10 @@ int tcp_listen(EndPoint point) {
return sockfd.release();
}

int tcp_listen(EndPoint point) {
return tcp_listen(point, NULL);
}

int get_local_side(int fd, EndPoint *out) {
struct sockaddr_storage addr;
socklen_t socklen = sizeof(addr);
Expand Down
3 changes: 3 additions & 0 deletions src/butil/endpoint.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ int tcp_connect(const EndPoint& server, int* self_port, int connect_timeout_ms);
// To enable SO_REUSEPORT for the whole program, enable gflag -reuse_port
// Returns the socket descriptor, -1 otherwise and errno is set.
int tcp_listen(EndPoint ip_and_port);
// If `before_listen' is not NULL, it will be called before bind/listen.
typedef void (*BeforeListenCallback)(int fd);
int tcp_listen(EndPoint ip_and_port, BeforeListenCallback before_listen);

// Get the local end of a socket connection
int get_local_side(int fd, EndPoint *out);
Expand Down
136 changes: 136 additions & 0 deletions test/brpc_socket_unittest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include "brpc/policy/most_common_message.h"
#include "brpc/policy/http_rpc_protocol.h"
#include "brpc/server.h"
#include "brpc/details/server_private_accessor.h"
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "health_check.pb.h"
Expand All @@ -59,6 +60,8 @@ DECLARE_bool(socket_keepalive);
DECLARE_int32(socket_keepalive_idle_s);
DECLARE_int32(socket_keepalive_interval_s);
DECLARE_int32(socket_keepalive_count);
DECLARE_int32(socket_recv_buffer_size);
DECLARE_int32(socket_send_buffer_size);
DECLARE_int32(socket_tcp_user_timeout_ms);
}

Expand Down Expand Up @@ -1123,6 +1126,48 @@ void CheckKeepalive(int fd,
ASSERT_EQ(expected_keepalive_count, keepalive_count);
}

struct SocketBufferValues {
int recv_buffer;
int send_buffer;

SocketBufferValues() : recv_buffer(0), send_buffer(0) {}
};

void GetSocketBufferValues(int fd, SocketBufferValues* values) {
socklen_t len = sizeof(values->recv_buffer);
ASSERT_EQ(0, getsockopt(fd, SOL_SOCKET, SO_RCVBUF,
&values->recv_buffer, &len));
len = sizeof(values->send_buffer);
ASSERT_EQ(0, getsockopt(fd, SOL_SOCKET, SO_SNDBUF,
&values->send_buffer, &len));
}

void GetExpectedSocketBufferValues(int buffer_size,
SocketBufferValues* expected) {
SocketBufferValues default_values;
butil::fd_guard default_fd(socket(AF_INET, SOCK_STREAM, 0));
ASSERT_GT(default_fd, 0);
GetSocketBufferValues(default_fd, &default_values);

butil::fd_guard reference_fd(socket(AF_INET, SOCK_STREAM, 0));
ASSERT_GT(reference_fd, 0);
ASSERT_EQ(0, setsockopt(reference_fd, SOL_SOCKET, SO_RCVBUF, &buffer_size,
sizeof(buffer_size)));
ASSERT_EQ(0, setsockopt(reference_fd, SOL_SOCKET, SO_SNDBUF, &buffer_size,
sizeof(buffer_size)));
GetSocketBufferValues(reference_fd, expected);

ASSERT_NE(default_values.recv_buffer, expected->recv_buffer);
ASSERT_NE(default_values.send_buffer, expected->send_buffer);
}

void CheckSocketBufferValues(int fd, const SocketBufferValues& expected) {
SocketBufferValues actual;
GetSocketBufferValues(fd, &actual);
ASSERT_EQ(expected.recv_buffer, actual.recv_buffer);
ASSERT_EQ(expected.send_buffer, actual.send_buffer);
}

TEST_F(SocketTest, keepalive) {
int default_keepalive = 0;
int default_keepalive_idle = 0;
Expand Down Expand Up @@ -1425,6 +1470,97 @@ TEST_F(SocketTest, keepalive_input_message) {
ASSERT_EQ(EBADF, errno);
}

TEST_F(SocketTest, socket_buffer_options_before_connect) {
gflags::FlagSaver flag_saver;
const int buffer_size = 256 * 1024;
brpc::FLAGS_socket_recv_buffer_size = buffer_size;
brpc::FLAGS_socket_send_buffer_size = buffer_size;

SocketBufferValues expected;
GetExpectedSocketBufferValues(buffer_size, &expected);

butil::EndPoint point;
ASSERT_EQ(0, str2endpoint("127.0.0.1:0", &point));
butil::fd_guard listening_fd(tcp_listen(point));
ASSERT_GT(listening_fd, 0) << berror();
ASSERT_EQ(0, butil::get_local_side(listening_fd, &point));

brpc::SocketOptions options;
options.remote_side = point;
brpc::SocketId id = brpc::INVALID_SOCKET_ID;
ASSERT_EQ(0, brpc::Socket::Create(options, &id));

brpc::SocketUniquePtr ptr;
ASSERT_EQ(0, brpc::Socket::Address(id, &ptr)) << "id=" << id;

const timespec duetime = butil::milliseconds_from_now(1000);
butil::fd_guard connected_fd(ptr->Connect(&duetime, NULL, NULL));
ASSERT_GT(connected_fd, 0);
CheckSocketBufferValues(connected_fd, expected);

ASSERT_EQ(0, ptr->SetFailed());
}

TEST_F(SocketTest, socket_buffer_options_before_accept) {
gflags::FlagSaver flag_saver;
const int buffer_size = 256 * 1024;
brpc::FLAGS_socket_recv_buffer_size = buffer_size;
brpc::FLAGS_socket_send_buffer_size = buffer_size;

SocketBufferValues expected;
GetExpectedSocketBufferValues(buffer_size, &expected);

butil::EndPoint point;
ASSERT_EQ(0, str2endpoint("127.0.0.1:0", &point));
brpc::Server server;
ASSERT_EQ(0, server.Start(point, NULL));
point = server.listen_address();

brpc::Acceptor* messenger =
brpc::ServerPrivateAccessor(&server).acceptor();
ASSERT_TRUE(messenger != NULL);
ASSERT_GT(messenger->listened_fd(), 0);
CheckSocketBufferValues(messenger->listened_fd(), expected);

// Accepted sockets should inherit the listener's buffer sizes, not use
// the flags at accept time.
brpc::FLAGS_socket_recv_buffer_size = buffer_size / 2;
brpc::FLAGS_socket_send_buffer_size = buffer_size / 2;

brpc::SocketOptions options;
options.remote_side = point;
options.connect_on_create = true;
brpc::SocketId id = brpc::INVALID_SOCKET_ID;
ASSERT_EQ(0, brpc::Socket::Create(options, &id));

const int64_t start_time = butil::cpuwide_time_us();
while (messenger->ConnectionCount() < 1) {
bthread_usleep(1000);
ASSERT_LT(butil::cpuwide_time_us(), start_time + 1000000L)
<< "Too long!";
}

std::vector<brpc::SocketId> connections;
messenger->ListConnections(&connections);
ASSERT_EQ(1ul, connections.size());

{
brpc::SocketUniquePtr accepted_socket;
ASSERT_EQ(0, brpc::Socket::Address(connections[0], &accepted_socket));
ASSERT_GT(accepted_socket->fd(), 0);
CheckSocketBufferValues(accepted_socket->fd(), expected);
ASSERT_EQ(0, accepted_socket->SetFailed());
}

{
brpc::SocketUniquePtr client_socket;
ASSERT_EQ(0, brpc::Socket::Address(id, &client_socket));
ASSERT_EQ(0, client_socket->SetFailed());
}
server.Stop(0);
server.Join();
}

#if defined(OS_LINUX)
void CheckTCPUserTimeout(int fd, int expect_tcp_user_timeout) {
int tcp_user_timeout = 0;
Expand Down
Loading