/* Copyright (C) 2019 MariaDB Corporation This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; version 2 of the License. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program; if not, write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. */ #include "SocketPool.h" #include "bytestream.h" #include "configcpp.h" #include "logger.h" #include "messageFormat.h" #include #include using namespace std; namespace { void log(logging::LOG_TYPE whichLogFile, const string& msg) { logging::Logger logger(12); // 12 = configcpp logger.logMessage(whichLogFile, msg, logging::LoggingID(12)); } } // namespace namespace idbdatafile { SocketPool::SocketPool() { config::Config* config = config::Config::makeConfig(); // the most 'config' ever put into a single line of code? string stmp; int64_t itmp = 0; try { stmp = config->getConfig("StorageManager", "MaxSockets"); itmp = strtol(stmp.c_str(), NULL, 10); if (itmp > 500 || itmp < 1) { string errmsg = "SocketPool(): Got a bad value '" + stmp + "' for StorageManager/MaxSockets. Range is 1-500."; log(logging::LOG_TYPE_CRITICAL, errmsg); throw runtime_error(errmsg); } maxSockets = itmp; } catch (exception& e) { ostringstream os; os << "SocketPool(): Using default of " << defaultSockets << "."; log(logging::LOG_TYPE_CRITICAL, os.str()); maxSockets = defaultSockets; } } SocketPool::~SocketPool() { boost::mutex::scoped_lock lock(mutex); for (uint i = 0; i < allSockets.size(); i++) ::close(allSockets[i]); } #define sm_check_error \ if (err < 0) \ { \ char _smbuf[80]; \ int l_errno = errno; \ log(logging::LOG_TYPE_ERROR, \ string("SocketPool: got a network error: ") + strerror_r(l_errno, _smbuf, 80)); \ remoteClosed(sock); \ return -1; \ } int SocketPool::send_recv(messageqcpp::ByteStream& in, messageqcpp::ByteStream* out) { messageqcpp::BSSizeType count = 0; messageqcpp::BSSizeType length = in.length(); int sock = -1; const uint8_t* inbuf = in.buf(); ssize_t err = 0; retry: int retries = 0; while (sock < 0) { sock = getSocket(); if (sock < 0) { if (++retries < 10) { // log(logging::LOG_TYPE_ERROR, "SocketPool::send_recv(): retrying in 5 sec..."); sleep(1); } else { errno = ECONNREFUSED; return -1; } } } storagemanager::sm_msg_header hdr; hdr.type = storagemanager::SM_MSG_START; hdr.payloadLen = length; hdr.flags = 0; // cout << "SP sending msg on sock " << sock << " with length = " << length << endl; err = ::write(sock, &hdr, sizeof(hdr)); if (err < 0 && errno == EPIPE) { log(logging::LOG_TYPE_WARNING, "SocketPool: remote connection is closed, getting a new one"); remoteClosed(sock); sock = -1; goto retry; } sm_check_error; while (count < length) { err = ::write(sock, &inbuf[count], length - count); sm_check_error; count += err; in.advance(err); } // cout << "SP sent msg with length = " << length << endl; out->restart(); uint8_t* outbuf; uint8_t window[8192]; uint remainingBytes = 0; uint i; storagemanager::sm_msg_header* resp = NULL; while (1) { // cout << "SP receiving msg on sock " << sock << endl; // here remainingBytes means the # of bytes from the previous message err = ::read(sock, &window[remainingBytes], 8192 - remainingBytes); sm_check_error; if (err == 0) { remoteClosed(sock); // TODO, a retry loop return -1; } uint endOfData = remainingBytes + err; // for clarity // scan for the 8-byte header. If it is fragmented, move the fragment to the front of the buffer // for the next iteration to handle. if (endOfData < storagemanager::SM_HEADER_LEN) { remainingBytes = endOfData; continue; } for (i = 0; i <= endOfData - storagemanager::SM_HEADER_LEN; i++) { if (*((uint*)&window[i]) == storagemanager::SM_MSG_START) { resp = (storagemanager::sm_msg_header*)&window[i]; break; } } if (resp == NULL) // didn't find the header yet { remainingBytes = endOfData - i; memmove(window, &window[i], remainingBytes); } else { // i == first byte of the header here // copy the payload fragment we got into the output bytestream uint startOfPayload = i + storagemanager::SM_HEADER_LEN; // for clarity out->needAtLeast(resp->payloadLen); outbuf = out->getInputPtr(); if (resp->payloadLen < endOfData - startOfPayload) cout << "SocketPool: warning! Probably got a bad length field! payload length = " << resp->payloadLen << " endOfData = " << endOfData << " startOfPayload = " << startOfPayload << endl; memcpy(outbuf, &window[startOfPayload], endOfData - startOfPayload); remainingBytes = resp->payloadLen - (endOfData - startOfPayload); // remainingBytes is now the # of bytes left to read out->advanceInputPtr(endOfData - startOfPayload); break; // done looking for the header, can fill the output buffer directly now. } } // read the rest of the payload directly into the output bytestream while (remainingBytes > 0) { err = ::read(sock, &outbuf[resp->payloadLen - remainingBytes], remainingBytes); sm_check_error; remainingBytes -= err; out->advanceInputPtr(err); } returnSocket(sock); return 0; } int SocketPool::getSocket() { boost::mutex::scoped_lock lock(mutex); int clientSocket; if (freeSockets.size() == 0 && allSockets.size() < maxSockets) { // make a new connection struct sockaddr_un addr; memset(&addr, 0, sizeof(addr)); addr.sun_family = AF_UNIX; strcpy(&addr.sun_path[1], &storagemanager::socket_name[1]); // first char is null... clientSocket = ::socket(AF_UNIX, SOCK_STREAM, 0); int err = ::connect(clientSocket, (const struct sockaddr*)&addr, sizeof(addr)); if (err >= 0) allSockets.push_back(clientSocket); else { int saved_errno = errno; ostringstream os; char buf[80]; os << "SocketPool::getSocket() failed to connect; got '" << strerror_r(saved_errno, buf, 80) << "'"; cout << os.str() << endl; log(logging::LOG_TYPE_ERROR, os.str()); close(clientSocket); errno = saved_errno; return -1; } return clientSocket; } // wait for a socket to become free while (freeSockets.size() == 0) socketAvailable.wait(lock); clientSocket = freeSockets.front(); freeSockets.pop_front(); return clientSocket; } void SocketPool::returnSocket(const int sock) { boost::mutex::scoped_lock lock(mutex); // cout << "returning socket " << sock << endl; freeSockets.push_back(sock); socketAvailable.notify_one(); } void SocketPool::remoteClosed(const int sock) { boost::mutex::scoped_lock lock(mutex); // cout << "closing socket " << sock << endl; ::close(sock); for (vector::iterator i = allSockets.begin(); i != allSockets.end(); ++i) if (*i == sock) { allSockets.erase(i, i + 1); break; } } } // namespace idbdatafile