You've already forked mariadb-columnstore-engine
							
							
				mirror of
				https://github.com/mariadb-corporation/mariadb-columnstore-engine.git
				synced 2025-10-31 18:30:33 +03:00 
			
		
		
		
	
		
			
				
	
	
		
			277 lines
		
	
	
		
			8.1 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			277 lines
		
	
	
		
			8.1 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| /* Copyright (C) 2019 MariaDB Corporation
 | |
| 
 | |
|    This program is free software; you can redistribute it and/or
 | |
|    modify it under the terms of the GNU General Public License
 | |
|    as published by the Free Software Foundation; version 2 of
 | |
|    the License.
 | |
| 
 | |
|    This program is distributed in the hope that it will be useful,
 | |
|    but WITHOUT ANY WARRANTY; without even the implied warranty of
 | |
|    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 | |
|    GNU General Public License for more details.
 | |
| 
 | |
|    You should have received a copy of the GNU General Public License
 | |
|    along with this program; if not, write to the Free Software
 | |
|    Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
 | |
|    MA 02110-1301, USA. */
 | |
| 
 | |
| #include "SocketPool.h"
 | |
| #include "bytestream.h"
 | |
| #include "configcpp.h"
 | |
| #include "logger.h"
 | |
| #include "messageFormat.h"
 | |
| 
 | |
| #include <sys/socket.h>
 | |
| #include <sys/un.h>
 | |
| 
 | |
| using namespace std;
 | |
| 
 | |
| namespace
 | |
| {
 | |
| void log(logging::LOG_TYPE whichLogFile, const string& msg)
 | |
| {
 | |
|   logging::Logger logger(12);  // 12 = configcpp
 | |
|   logger.logMessage(whichLogFile, msg, logging::LoggingID(12));
 | |
| }
 | |
| 
 | |
| }  // namespace
 | |
| 
 | |
| namespace idbdatafile
 | |
| {
 | |
| SocketPool::SocketPool()
 | |
| {
 | |
|   config::Config* config =
 | |
|       config::Config::makeConfig();  // the most 'config' ever put into a single line of code?
 | |
|   string stmp;
 | |
|   int64_t itmp = 0;
 | |
| 
 | |
|   try
 | |
|   {
 | |
|     stmp = config->getConfig("StorageManager", "MaxSockets");
 | |
|     itmp = strtol(stmp.c_str(), NULL, 10);
 | |
|     if (itmp > 500 || itmp < 1)
 | |
|     {
 | |
|       string errmsg =
 | |
|           "SocketPool(): Got a bad value '" + stmp + "' for StorageManager/MaxSockets.  Range is 1-500.";
 | |
|       log(logging::LOG_TYPE_CRITICAL, errmsg);
 | |
|       throw runtime_error(errmsg);
 | |
|     }
 | |
|     maxSockets = itmp;
 | |
|   }
 | |
|   catch (exception& e)
 | |
|   {
 | |
|     ostringstream os;
 | |
|     os << "SocketPool(): Using default of " << defaultSockets << ".";
 | |
|     log(logging::LOG_TYPE_CRITICAL, os.str());
 | |
|     maxSockets = defaultSockets;
 | |
|   }
 | |
| }
 | |
| 
 | |
| SocketPool::~SocketPool()
 | |
| {
 | |
|   boost::mutex::scoped_lock lock(mutex);
 | |
| 
 | |
|   for (uint i = 0; i < allSockets.size(); i++)
 | |
|     ::close(allSockets[i]);
 | |
| }
 | |
| 
 | |
| #define sm_check_error                                                                  \
 | |
|   if (err < 0)                                                                          \
 | |
|   {                                                                                     \
 | |
|     char _smbuf[80];                                                                    \
 | |
|     int l_errno = errno;                                                                \
 | |
|     log(logging::LOG_TYPE_ERROR,                                                        \
 | |
|         string("SocketPool: got a network error: ") + strerror_r(l_errno, _smbuf, 80)); \
 | |
|     remoteClosed(sock);                                                                 \
 | |
|     return -1;                                                                          \
 | |
|   }
 | |
| 
 | |
| int SocketPool::send_recv(messageqcpp::ByteStream& in, messageqcpp::ByteStream* out)
 | |
| {
 | |
|   messageqcpp::BSSizeType count = 0;
 | |
|   messageqcpp::BSSizeType length = in.length();
 | |
|   int sock = -1;
 | |
|   const uint8_t* inbuf = in.buf();
 | |
|   ssize_t err = 0;
 | |
| 
 | |
| retry:
 | |
|   int retries = 0;
 | |
|   while (sock < 0)
 | |
|   {
 | |
|     sock = getSocket();
 | |
|     if (sock < 0)
 | |
|     {
 | |
|       if (++retries < 10)
 | |
|       {
 | |
|         // log(logging::LOG_TYPE_ERROR, "SocketPool::send_recv(): retrying in 5 sec...");
 | |
|         sleep(1);
 | |
|       }
 | |
|       else
 | |
|       {
 | |
|         errno = ECONNREFUSED;
 | |
|         return -1;
 | |
|       }
 | |
|     }
 | |
|   }
 | |
| 
 | |
|   storagemanager::sm_msg_header hdr;
 | |
|   hdr.type = storagemanager::SM_MSG_START;
 | |
|   hdr.payloadLen = length;
 | |
|   hdr.flags = 0;
 | |
|   // cout << "SP sending msg on sock " << sock << " with length = " << length << endl;
 | |
|   err = ::write(sock, &hdr, sizeof(hdr));
 | |
|   if (err < 0 && errno == EPIPE)
 | |
|   {
 | |
|     log(logging::LOG_TYPE_WARNING, "SocketPool: remote connection is closed, getting a new one");
 | |
|     remoteClosed(sock);
 | |
|     sock = -1;
 | |
|     goto retry;
 | |
|   }
 | |
|   sm_check_error;
 | |
|   while (count < length)
 | |
|   {
 | |
|     err = ::write(sock, &inbuf[count], length - count);
 | |
|     sm_check_error;
 | |
|     count += err;
 | |
|     in.advance(err);
 | |
|   }
 | |
|   // cout << "SP sent msg with length = " << length << endl;
 | |
| 
 | |
|   out->restart();
 | |
|   uint8_t* outbuf;
 | |
|   uint8_t window[8192];
 | |
|   uint remainingBytes = 0;
 | |
|   uint i;
 | |
|   storagemanager::sm_msg_header* resp = NULL;
 | |
| 
 | |
|   while (1)
 | |
|   {
 | |
|     // cout << "SP receiving msg on sock " << sock << endl;
 | |
|     // here remainingBytes means the # of bytes from the previous message
 | |
|     err = ::read(sock, &window[remainingBytes], 8192 - remainingBytes);
 | |
|     sm_check_error;
 | |
|     if (err == 0)
 | |
|     {
 | |
|       remoteClosed(sock);
 | |
|       // TODO, a retry loop
 | |
|       return -1;
 | |
|     }
 | |
|     uint endOfData = remainingBytes + err;  // for clarity
 | |
| 
 | |
|     // scan for the 8-byte header.  If it is fragmented, move the fragment to the front of the buffer
 | |
|     // for the next iteration to handle.
 | |
| 
 | |
|     if (endOfData < storagemanager::SM_HEADER_LEN)
 | |
|     {
 | |
|       remainingBytes = endOfData;
 | |
|       continue;
 | |
|     }
 | |
| 
 | |
|     for (i = 0; i <= endOfData - storagemanager::SM_HEADER_LEN; i++)
 | |
|     {
 | |
|       if (*((uint*)&window[i]) == storagemanager::SM_MSG_START)
 | |
|       {
 | |
|         resp = (storagemanager::sm_msg_header*)&window[i];
 | |
|         break;
 | |
|       }
 | |
|     }
 | |
| 
 | |
|     if (resp == NULL)  // didn't find the header yet
 | |
|     {
 | |
|       remainingBytes = endOfData - i;
 | |
|       memmove(window, &window[i], remainingBytes);
 | |
|     }
 | |
|     else
 | |
|     {
 | |
|       // i == first byte of the header here
 | |
|       // copy the payload fragment we got into the output bytestream
 | |
|       uint startOfPayload = i + storagemanager::SM_HEADER_LEN;  // for clarity
 | |
|       out->needAtLeast(resp->payloadLen);
 | |
|       outbuf = out->getInputPtr();
 | |
|       if (resp->payloadLen < endOfData - startOfPayload)
 | |
|         cout << "SocketPool: warning!  Probably got a bad length field!  payload length = "
 | |
|              << resp->payloadLen << " endOfData = " << endOfData << " startOfPayload = " << startOfPayload
 | |
|              << endl;
 | |
|       memcpy(outbuf, &window[startOfPayload], endOfData - startOfPayload);
 | |
|       remainingBytes = resp->payloadLen -
 | |
|                        (endOfData - startOfPayload);  // remainingBytes is now the # of bytes left to read
 | |
|       out->advanceInputPtr(endOfData - startOfPayload);
 | |
|       break;  // done looking for the header, can fill the output buffer directly now.
 | |
|     }
 | |
|   }
 | |
| 
 | |
|   // read the rest of the payload directly into the output bytestream
 | |
|   while (remainingBytes > 0)
 | |
|   {
 | |
|     err = ::read(sock, &outbuf[resp->payloadLen - remainingBytes], remainingBytes);
 | |
|     sm_check_error;
 | |
|     remainingBytes -= err;
 | |
|     out->advanceInputPtr(err);
 | |
|   }
 | |
|   returnSocket(sock);
 | |
|   return 0;
 | |
| }
 | |
| 
 | |
| int SocketPool::getSocket()
 | |
| {
 | |
|   boost::mutex::scoped_lock lock(mutex);
 | |
|   int clientSocket;
 | |
| 
 | |
|   if (freeSockets.size() == 0 && allSockets.size() < maxSockets)
 | |
|   {
 | |
|     // make a new connection
 | |
|     struct sockaddr_un addr;
 | |
|     memset(&addr, 0, sizeof(addr));
 | |
|     addr.sun_family = AF_UNIX;
 | |
|     strcpy(&addr.sun_path[1], &storagemanager::socket_name[1]);  // first char is null...
 | |
|     clientSocket = ::socket(AF_UNIX, SOCK_STREAM, 0);
 | |
|     int err = ::connect(clientSocket, (const struct sockaddr*)&addr, sizeof(addr));
 | |
|     if (err >= 0)
 | |
|       allSockets.push_back(clientSocket);
 | |
|     else
 | |
|     {
 | |
|       int saved_errno = errno;
 | |
|       ostringstream os;
 | |
|       char buf[80];
 | |
|       os << "SocketPool::getSocket() failed to connect; got '" << strerror_r(saved_errno, buf, 80) << "'";
 | |
|       cout << os.str() << endl;
 | |
|       log(logging::LOG_TYPE_ERROR, os.str());
 | |
|       close(clientSocket);
 | |
|       errno = saved_errno;
 | |
|       return -1;
 | |
|     }
 | |
|     return clientSocket;
 | |
|   }
 | |
| 
 | |
|   // wait for a socket to become free
 | |
|   while (freeSockets.size() == 0)
 | |
|     socketAvailable.wait(lock);
 | |
| 
 | |
|   clientSocket = freeSockets.front();
 | |
|   freeSockets.pop_front();
 | |
|   return clientSocket;
 | |
| }
 | |
| 
 | |
| void SocketPool::returnSocket(const int sock)
 | |
| {
 | |
|   boost::mutex::scoped_lock lock(mutex);
 | |
|   // cout << "returning socket " << sock << endl;
 | |
|   freeSockets.push_back(sock);
 | |
|   socketAvailable.notify_one();
 | |
| }
 | |
| 
 | |
| void SocketPool::remoteClosed(const int sock)
 | |
| {
 | |
|   boost::mutex::scoped_lock lock(mutex);
 | |
|   // cout << "closing socket " << sock << endl;
 | |
|   ::close(sock);
 | |
|   for (vector<int>::iterator i = allSockets.begin(); i != allSockets.end(); ++i)
 | |
|     if (*i == sock)
 | |
|     {
 | |
|       allSockets.erase(i, i + 1);
 | |
|       break;
 | |
|     }
 | |
| }
 | |
| 
 | |
| }  // namespace idbdatafile
 |