From 8ef4e09aa496ed452899e44b097ad25c791e2435 Mon Sep 17 00:00:00 2001 From: Patrick LeBlanc Date: Mon, 4 Feb 2019 14:44:35 -0600 Subject: [PATCH] Changed ownership of the socket. --- src/AppendTask.cpp | 19 ++++++++++--------- src/AppendTask.h | 2 +- src/CopyTask.cpp | 15 ++++++++------- src/CopyTask.h | 2 +- src/ListDirectoryTask.cpp | 28 +++++++++++----------------- src/ListDirectoryTask.h | 2 +- src/OpenTask.cpp | 9 +++++---- src/OpenTask.h | 2 +- src/PingTask.cpp | 9 +++++---- src/PingTask.h | 2 +- src/PosixTask.cpp | 16 +--------------- src/PosixTask.h | 6 ++---- src/ProcessTask.cpp | 12 +++++++++++- src/ReadTask.cpp | 13 +++++++------ src/ReadTask.h | 2 +- src/StatTask.cpp | 15 ++++++++------- src/StatTask.h | 2 +- src/TruncateTask.cpp | 15 ++++++++------- src/TruncateTask.h | 2 +- src/UnlinkTask.cpp | 15 ++++++++------- src/UnlinkTask.h | 2 +- src/WriteTask.cpp | 21 +++++++++++---------- src/WriteTask.h | 2 +- 23 files changed, 105 insertions(+), 108 deletions(-) diff --git a/src/AppendTask.cpp b/src/AppendTask.cpp index 92df13ff1..0461e838a 100644 --- a/src/AppendTask.cpp +++ b/src/AppendTask.cpp @@ -16,32 +16,32 @@ AppendTask::~AppendTask() { } -#define check_error(msg) \ +#define check_error(msg, ret) \ if (!success) \ { \ handleError(msg, errno); \ - return; \ + return ret; \ } #define min(x, y) (x < y ? x : y) -void AppendTask::run() +bool AppendTask::run() { bool success; uint8_t cmdbuf[1024] = {0}; int err; success = read(cmdbuf, sizeof(append_cmd)); - check_error("AppendTask read"); + check_error("AppendTask read", false); append_cmd *cmd = (append_cmd *) cmdbuf; if (cmd->flen > 1023 - sizeof(*cmd)) { handleError("AppendTask", ENAMETOOLONG); - return; + return true; } success = read(&cmdbuf[sizeof(*cmd)], cmd->flen); - check_error("AppendTask read"); + check_error("AppendTask read", false); size_t readCount = 0, writeCount = 0; vector databuf; @@ -52,7 +52,7 @@ void AppendTask::run() { uint toRead = min(cmd->count - readCount, bufsize); success = read(&databuf[0], toRead); - check_error("AppendTask read data"); + check_error("AppendTask read data", false); readCount += toRead; uint writePos = 0; while (writeCount < readCount) @@ -75,14 +75,15 @@ void AppendTask::run() resp->payloadLen = 8; resp->returnCode = -1; *((int *) &resp[1]) = errno; - write((uint8_t *) respbuf, sizeof(sm_msg_resp) + 4); + success = write((uint8_t *) respbuf, sizeof(sm_msg_resp) + 4); } else { resp->payloadLen = 4; resp->returnCode = writeCount; - write((uint8_t *) respbuf, sizeof(sm_msg_resp)); + success = write((uint8_t *) respbuf, sizeof(sm_msg_resp)); } + return success; } } diff --git a/src/AppendTask.h b/src/AppendTask.h index 6183801b0..b02e5c49b 100644 --- a/src/AppendTask.h +++ b/src/AppendTask.h @@ -14,7 +14,7 @@ class AppendTask : public PosixTask AppendTask(int sock, uint length); virtual ~AppendTask(); - void run(); + bool run(); private: AppendTask(); diff --git a/src/CopyTask.cpp b/src/CopyTask.cpp index a97f60c8b..bb35b1474 100644 --- a/src/CopyTask.cpp +++ b/src/CopyTask.cpp @@ -16,14 +16,14 @@ CopyTask::~CopyTask() { } -#define check_error(msg) \ +#define check_error(msg, ret) \ if (!success) \ { \ handleError(msg, errno); \ - return; \ + return ret; \ } -void CopyTask::run() +bool CopyTask::run() { bool success; uint8_t buf[2048] = {0}; @@ -31,11 +31,11 @@ void CopyTask::run() if (getLength() > 2047) { handleError("CopyTask read", ENAMETOOLONG); - return; + return true; } success = read(buf, getLength()); - check_error("CopyTask read"); + check_error("CopyTask read", false); copy_cmd *cmd = (copy_cmd *) buf; string filename1(cmd->file1.filename, cmd->file1.flen); // need to copy this in case it's not null terminated f_name *filename2 = (f_name *) &buf[sizeof(copy_cmd) + cmd->file1.flen]; @@ -44,14 +44,15 @@ void CopyTask::run() if (err) { handleError("CopyTask copy", errno); - return; + return true; } sm_msg_resp *resp = (sm_msg_resp *) buf; resp->type = SM_MSG_START; resp->payloadLen = 4; resp->returnCode = 0; - write(buf, sizeof(sm_msg_resp)); + success = write(buf, sizeof(sm_msg_resp)); + return success; } } diff --git a/src/CopyTask.h b/src/CopyTask.h index 80ba927e9..a938d9e4c 100644 --- a/src/CopyTask.h +++ b/src/CopyTask.h @@ -13,7 +13,7 @@ class CopyTask : public PosixTask CopyTask(int sock, uint length); virtual ~CopyTask(); - void run(); + bool run(); private: CopyTask(); diff --git a/src/ListDirectoryTask.cpp b/src/ListDirectoryTask.cpp index c83bfd2b2..bf19c7b8a 100644 --- a/src/ListDirectoryTask.cpp +++ b/src/ListDirectoryTask.cpp @@ -18,18 +18,11 @@ ListDirectoryTask::~ListDirectoryTask() { } -#define check_error(msg) \ +#define check_error(msg, ret) \ if (!success) \ { \ handleError(msg, errno); \ - return; \ - } - -#define check_error_2(msg) \ - if (!success) \ - { \ - handleError(msg, errno); \ - return false; \ + return ret; \ } #define min(x, y) (x < y ? x : y) @@ -40,7 +33,7 @@ bool ListDirectoryTask::writeString(uint8_t *buf, int *offset, int size, const s if (size - *offset < 4) // eh, let's not frag 4 bytes. { success = write(buf, *offset); - check_error_2("ListDirectoryTask::writeString()"); + check_error("ListDirectoryTask::writeString()", false); *offset = 0; } uint count = 0, len = str.length(); @@ -55,14 +48,14 @@ bool ListDirectoryTask::writeString(uint8_t *buf, int *offset, int size, const s if (*offset == size) { success = write(buf, *offset); - check_error_2("ListDirectoryTask::writeString()"); + check_error("ListDirectoryTask::writeString()", false); *offset = 0; } } return true; } -void ListDirectoryTask::run() +bool ListDirectoryTask::run() { bool success; uint8_t buf[1024] = {0}; @@ -70,11 +63,11 @@ void ListDirectoryTask::run() if (getLength() > 1023) { handleError("ListDirectoryTask read", ENAMETOOLONG); - return; + return true; } success = read(buf, getLength()); - check_error("ListDirectoryTask read"); + check_error("ListDirectoryTask read", false); listdir_cmd *cmd = (listdir_cmd *) buf; vector listing; @@ -82,7 +75,7 @@ void ListDirectoryTask::run() if (err) { handleError("ListDirectory", errno); - return; + return true; } // be careful modifying the listdir return types... @@ -101,14 +94,15 @@ void ListDirectoryTask::run() for (uint i = 0; i < listing.size(); i++) { success = writeString(buf, &offset, 1024, listing[i]); - check_error("ListDirectoryTask write"); + check_error("ListDirectoryTask write", false); } if (offset != 0) { success = write(buf, offset); - check_error("ListDirectoryTask write"); + check_error("ListDirectoryTask write", false); } + return true; } } diff --git a/src/ListDirectoryTask.h b/src/ListDirectoryTask.h index 42544bfe2..410ccffdc 100644 --- a/src/ListDirectoryTask.h +++ b/src/ListDirectoryTask.h @@ -14,7 +14,7 @@ class ListDirectoryTask : public PosixTask ListDirectoryTask(int sock, uint length); virtual ~ListDirectoryTask(); - void run(); + bool run(); private: ListDirectoryTask(); diff --git a/src/OpenTask.cpp b/src/OpenTask.cpp index 3329dc7e6..d36f32196 100644 --- a/src/OpenTask.cpp +++ b/src/OpenTask.cpp @@ -19,7 +19,7 @@ OpenTask::~OpenTask() { } -void OpenTask::run() +bool OpenTask::run() { /* get the parameters @@ -32,14 +32,14 @@ void OpenTask::run() if (getLength() > 1023) { handleError("OpenTask read1", ENAMETOOLONG); - return; + return true; } success = read(buf, getLength()); if (!success) { handleError("OpenTask read2", errno); - return; + return false; } open_cmd *cmd = (open_cmd *) buf; @@ -48,7 +48,7 @@ void OpenTask::run() if (err) { handleError("OpenTask open", errno); - return; + return true; } sm_msg_resp *resp = (sm_msg_resp *) buf; @@ -58,6 +58,7 @@ void OpenTask::run() success = write(buf, sizeof(struct stat) + sizeof(sm_msg_resp)); if (!success) handleError("OpenTask write", errno); + return success; } } diff --git a/src/OpenTask.h b/src/OpenTask.h index 0bba356a6..134c42aa5 100644 --- a/src/OpenTask.h +++ b/src/OpenTask.h @@ -13,7 +13,7 @@ class OpenTask : public PosixTask OpenTask(int sock, uint length); virtual ~OpenTask(); - void run(); + bool run(); private: OpenTask(); diff --git a/src/PingTask.cpp b/src/PingTask.cpp index 2b11e5420..d7eefc4d0 100644 --- a/src/PingTask.cpp +++ b/src/PingTask.cpp @@ -14,7 +14,7 @@ PingTask::~PingTask() { } -void PingTask::run() +bool PingTask::run() { // not much to check on for Milestone 1 @@ -23,14 +23,14 @@ void PingTask::run() if (getLength() > 1) { handleError("PingTask", E2BIG); - return; + return true; } // consume the msg bool success = read(&buf, getLength()); if (!success) { handleError("PingTask", errno); - return; + return false; } // send generic success response @@ -38,7 +38,8 @@ void PingTask::run() ret.type = SM_MSG_START; ret.payloadLen = 4; ret.returnCode = 0; - write((uint8_t *) &ret, sizeof(ret)); + success = write((uint8_t *) &ret, sizeof(ret)); + return success; } } diff --git a/src/PingTask.h b/src/PingTask.h index 24f836ad4..20030fb05 100644 --- a/src/PingTask.h +++ b/src/PingTask.h @@ -14,7 +14,7 @@ class PingTask : public PosixTask PingTask(int sock, uint length); virtual ~PingTask(); - void run(); + bool run(); private: PingTask(); diff --git a/src/PosixTask.cpp b/src/PosixTask.cpp index b5479c16e..013d726f6 100644 --- a/src/PosixTask.cpp +++ b/src/PosixTask.cpp @@ -19,8 +19,7 @@ PosixTask::PosixTask(int _sock, uint _length) : remainingLengthInStream(_length), remainingLengthForCaller(_length), bufferPos(0), - bufferLen(0), - socketReturned(false) + bufferLen(0) { ioc = IOCoordinator::get(); } @@ -28,8 +27,6 @@ PosixTask::PosixTask(int _sock, uint _length) : PosixTask::~PosixTask() { consumeMsg(); - if (!socketReturned) - returnSocket(); } void PosixTask::handleError(const char *name, int errCode) @@ -46,17 +43,6 @@ void PosixTask::handleError(const char *name, int errCode) // TODO: construct and log a message cout << name << " caught an error: " << strerror_r(errCode, buf, 80) << endl; - socketError(); -} - -void PosixTask::returnSocket() -{ - socketReturned = true; -} - -void PosixTask::socketError() -{ - socketReturned = true; } uint PosixTask::getRemainingLength() diff --git a/src/PosixTask.h b/src/PosixTask.h index b406968ef..13ad3a901 100644 --- a/src/PosixTask.h +++ b/src/PosixTask.h @@ -17,7 +17,8 @@ class PosixTask PosixTask(int sock, uint length); virtual ~PosixTask(); - virtual void run() = 0; + // this should return false if there was a network error, true otherwise including for other errors + virtual bool run() = 0; void primeBuffer(); protected: @@ -28,8 +29,6 @@ class PosixTask uint getLength(); // returns the total length of the msg uint getRemainingLength(); // returns the remaining length from the caller's perspective void handleError(const char *name, int errCode); - void returnSocket(); - void socketError(); IOCoordinator *ioc; @@ -44,7 +43,6 @@ class PosixTask uint8_t localBuffer[bufferSize]; uint bufferPos; uint bufferLen; - bool socketReturned; }; } diff --git a/src/ProcessTask.cpp b/src/ProcessTask.cpp index f146bf0ee..825e16f63 100644 --- a/src/ProcessTask.cpp +++ b/src/ProcessTask.cpp @@ -98,7 +98,17 @@ void ProcessTask::operator()() throw runtime_error("ProcessTask: got an unknown opcode"); } task->primeBuffer(); - task->run(); + bool success = task->run(); + if (!success) + { + //SessionManager::get()->socketError(sock); + //returnedSock = true; + } + else + { + SessionManager::get()->returnSocket(sock); + returnedSock = true; + } } diff --git a/src/ReadTask.cpp b/src/ReadTask.cpp index 3f86aa9e2..aa62b4b87 100644 --- a/src/ReadTask.cpp +++ b/src/ReadTask.cpp @@ -16,26 +16,26 @@ ReadTask::~ReadTask() { } -#define check_error(msg) \ +#define check_error(msg, ret) \ if (!success) \ { \ handleError(msg, errno); \ - return; \ + return ret; \ } -void ReadTask::run() +bool ReadTask::run() { uint8_t buf[1024] = {0}; // get the parameters if (getLength() > 1023) { handleError("ReadTask read", EFAULT); - return; + return true; } bool success; success = read(buf, getLength()); - check_error("ReadTask read cmd"); + check_error("ReadTask read cmd", false); cmd_overlay *cmd = (cmd_overlay *) buf; // read from IOC, write to the socket @@ -67,7 +67,8 @@ void ReadTask::run() count += err; } - write(outbuf); + success = write(outbuf); + return success; } diff --git a/src/ReadTask.h b/src/ReadTask.h index d388c42cc..1f659f12c 100644 --- a/src/ReadTask.h +++ b/src/ReadTask.h @@ -14,7 +14,7 @@ class ReadTask : public PosixTask ReadTask(int sock, uint length); virtual ~ReadTask(); - void run(); + bool run(); private: ReadTask(); diff --git a/src/StatTask.cpp b/src/StatTask.cpp index 40114afc5..45054a518 100644 --- a/src/StatTask.cpp +++ b/src/StatTask.cpp @@ -20,25 +20,25 @@ StatTask::~StatTask() { } -#define check_error(msg) \ +#define check_error(msg, ret) \ if (!success) \ { \ handleError(msg, errno); \ - return; \ + return ret; \ } -void StatTask::run() +bool StatTask::run() { bool success; uint8_t buf[1024] = {0}; if (getLength() > 1023) { handleError("StatTask read", ENAMETOOLONG); - return; + return true; } success = read(buf, getLength()); - check_error("StatTask read"); + check_error("StatTask read", false); stat_cmd *cmd = (stat_cmd *) buf; sm_msg_resp *resp = (sm_msg_resp *) buf; @@ -46,13 +46,14 @@ void StatTask::run() if (err) { handleError("StatTask stat", errno); - return; + return true; } resp->type = SM_MSG_START; resp->payloadLen = sizeof(struct stat) + 4; resp->returnCode = 0; - write(buf, sizeof(*resp) + sizeof(struct stat)); + success = write(buf, sizeof(*resp) + sizeof(struct stat)); + return success; } } diff --git a/src/StatTask.h b/src/StatTask.h index 00ed9a109..265c5c035 100644 --- a/src/StatTask.h +++ b/src/StatTask.h @@ -13,7 +13,7 @@ class StatTask : public PosixTask StatTask(int sock, uint length); virtual ~StatTask(); - void run(); + bool run(); private: StatTask(); diff --git a/src/TruncateTask.cpp b/src/TruncateTask.cpp index 1e18dc8be..b85c9d3dd 100644 --- a/src/TruncateTask.cpp +++ b/src/TruncateTask.cpp @@ -16,39 +16,40 @@ TruncateTask::~TruncateTask() { } -#define check_error(msg) \ +#define check_error(msg, ret) \ if (!success) \ { \ handleError(msg, errno); \ - return; \ + return ret; \ } -void TruncateTask::run() +bool TruncateTask::run() { bool success; uint8_t buf[1024] = {0}; if (getLength() > 1023) { handleError("TruncateTask read", ENAMETOOLONG); - return; + return false; } success = read(buf, getLength()); - check_error("TruncateTask read"); + check_error("TruncateTask read", false); truncate_cmd *cmd = (truncate_cmd *) buf; int err = ioc->truncate(cmd->filename, cmd->length); if (err) { handleError("TruncateTask truncate", errno); - return; + return true; } sm_msg_resp *resp = (sm_msg_resp *) buf; resp->type = SM_MSG_START; resp->payloadLen = 4; resp->returnCode = 0; - write(buf, sizeof(sm_msg_resp)); + success = write(buf, sizeof(sm_msg_resp)); + return success; } } diff --git a/src/TruncateTask.h b/src/TruncateTask.h index c4d9f6796..871c55d34 100644 --- a/src/TruncateTask.h +++ b/src/TruncateTask.h @@ -13,7 +13,7 @@ class TruncateTask : public PosixTask TruncateTask(int sock, uint length); virtual ~TruncateTask(); - void run(); + bool run(); private: TruncateTask(); diff --git a/src/UnlinkTask.cpp b/src/UnlinkTask.cpp index 9b5484b52..08e4e9710 100644 --- a/src/UnlinkTask.cpp +++ b/src/UnlinkTask.cpp @@ -16,40 +16,41 @@ UnlinkTask::~UnlinkTask() { } -#define check_error(msg) \ +#define check_error(msg, ret) \ if (!success) \ { \ handleError(msg, errno); \ - return; \ + return ret; \ } -void UnlinkTask::run() +bool UnlinkTask::run() { bool success; uint8_t buf[1024] = {0}; if (getLength() > 1023) { handleError("UnlinkTask read", ENAMETOOLONG); - return; + return true; } success = read(buf, getLength()); - check_error("UnlinkTask read"); + check_error("UnlinkTask read", false); unlink_cmd *cmd = (unlink_cmd *) buf; int err = ioc->unlink(cmd->filename); if (err) { handleError("UnlinkTask unlink", errno); - return; + return true; } sm_msg_resp *resp = (sm_msg_resp *) buf; resp->type = SM_MSG_START; resp->payloadLen = 4; resp->returnCode = 0; - write(buf, sizeof(*resp)); + success = write(buf, sizeof(*resp)); + return success; } } diff --git a/src/UnlinkTask.h b/src/UnlinkTask.h index a92521763..3669af62e 100644 --- a/src/UnlinkTask.h +++ b/src/UnlinkTask.h @@ -13,7 +13,7 @@ class UnlinkTask : public PosixTask UnlinkTask(int sock, uint length); virtual ~UnlinkTask(); - void run(); + bool run(); private: UnlinkTask(); diff --git a/src/WriteTask.cpp b/src/WriteTask.cpp index 878e2a4fb..ec19f8493 100644 --- a/src/WriteTask.cpp +++ b/src/WriteTask.cpp @@ -17,31 +17,31 @@ WriteTask::~WriteTask() { } -#define check_error(msg) \ +#define check_error(msg, ret) \ if (!success) \ { \ handleError(msg, errno); \ - return; \ + return ret; \ } #define min(x, y) (x < y ? x : y) -void WriteTask::run() +bool WriteTask::run() { bool success; uint8_t cmdbuf[1024] = {0}; success = read(cmdbuf, sizeof(write_cmd)); - check_error("WriteTask read"); + check_error("WriteTask read", false); write_cmd *cmd = (write_cmd *) cmdbuf; if (cmd->flen > 1023 - sizeof(*cmd)) { handleError("WriteTask", ENAMETOOLONG); - return; + return true; } success = read(&cmdbuf[sizeof(*cmd)], cmd->flen); - check_error("WriteTask read"); + check_error("WriteTask read", false); size_t readCount = 0, writeCount = 0; vector databuf; @@ -52,12 +52,12 @@ void WriteTask::run() { uint toRead = min(cmd->count - readCount, bufsize); success = read(&databuf[0], toRead); - check_error("WriteTask read data"); + check_error("WriteTask read data", false); readCount += toRead; uint writePos = 0; while (writeCount < readCount) { - int err = ioc->append(cmd->filename, &databuf[writePos], toRead - writePos); + int err = ioc->write(cmd->filename, &databuf[writePos], cmd->offset + writeCount, toRead - writePos); if (err <= 0) break; writeCount += err; @@ -75,14 +75,15 @@ void WriteTask::run() resp->payloadLen = 8; resp->returnCode = -1; *((int *) &resp[1]) = errno; - write((uint8_t *) respbuf, sizeof(sm_msg_resp) + 4); + success = write((uint8_t *) respbuf, sizeof(sm_msg_resp) + 4); } else { resp->payloadLen = 4; resp->returnCode = writeCount; - write((uint8_t *) respbuf, sizeof(sm_msg_resp)); + success = write((uint8_t *) respbuf, sizeof(sm_msg_resp)); } + return success; } } diff --git a/src/WriteTask.h b/src/WriteTask.h index a0194f451..4187b9556 100644 --- a/src/WriteTask.h +++ b/src/WriteTask.h @@ -14,7 +14,7 @@ class WriteTask : public PosixTask WriteTask(int sock, uint length); virtual ~WriteTask(); - void run(); + bool run(); private: WriteTask();