1 #ifndef STATE_MACHINE_HPP 2 #define STATE_MACHINE_HPP 14 #include <unordered_map> 68 using ConnectionID =
typename Identifier::ConnectionID;
69 using Hasher =
typename Identifier::Hasher;
// Sentinel stored in a State's timeoutID meaning "no timeout pending" for
// that connection; checked before erasing from timeoutFunctions and
// restored after a timeout fires.
70 static constexpr
auto timeoutIDInvalid = std::numeric_limits<uint32_t>::max();
// Signature of a state function: invoked with the connection's State, the
// packet that triggered it, and a FunIface through which the function can
// request transitions, timeouts and packet handling (see runPkt()).
83 using stateFun = std::function<void(State &, Packet *, FunIface &)>;
// Signature of a timeout function: invoked from runPktBatch() once a timeout
// registered via FunIface::setTimeout() has expired; no packet is passed.
86 using timeoutFun = std::function<void(State &, FunIface &)>;
133 bool immediateTransition;
138 : sm(sm), pktIdx(pktIdx), pktsBA(pktsBA), cID(cID), state(state), sendPkt(true),
139 immediateTransition(false){};
149 if (immediateTransition) {
166 assert((sm->functions.size() - 1) >= state.
state);
167 auto fun = sm->functions[state.
state];
168 assert(fun !=
nullptr);
170 D(std::cout <<
"Running Function" << std::endl;)
171 fun(state, pktsBA[pktIdx], *
this);
173 if (state.
state == sm->endStateID) {
174 D(std::cout <<
"Reached endStateID - deleting connection" << std::endl;)
175 sm->removeState(cID);
188 throw new std::runtime_error(
"StateMachine::FunIface::getPkt() not implemented");
211 state.
state = newState;
212 this->immediateTransition =
true;
220 std::chrono::time_point<std::chrono::steady_clock> now =
221 std::chrono::steady_clock::now();
222 std::chrono::time_point<std::chrono::steady_clock> then = now + timeout;
226 t.timeoutID = sm->curTimeoutID++;
230 sm->timeoutsQ.push(t);
241 sm->timeoutFunctions.emplace(
242 t.timeoutID, std::make_unique<struct TimeoutData>(cID, fun));
// Number of independently locked buckets in the pool. Must be a power of two,
// because a bucket is selected with `hasher(cID) & bucketMask`.
const unsigned int numBuckets = 8;
// Derived from numBuckets so the mask can never drift out of sync with it.
const uint64_t bucketMask = numBuckets - 1;
251 std::unordered_map<ConnectionID, State, Hasher> **newStates;
256 newStates =
reinterpret_cast<std::unordered_map<ConnectionID, State, Hasher> **
>(
257 malloc(
sizeof(
void *) * numBuckets));
258 for (
unsigned int i = 0; i < numBuckets; i++) {
259 newStates[i] =
new std::unordered_map<ConnectionID, State, Hasher>();
263 reinterpret_cast<SpinLockCLSize **
>(malloc(
sizeof(
void *) * numBuckets));
264 for (
unsigned int i = 0; i < numBuckets; i++) {
270 for (
unsigned int i = 0; i < numBuckets; i++) {
271 delete (newStates[i]);
275 for (
unsigned int i = 0; i < numBuckets; i++) {
276 newStatesLock[i]->
lock();
277 delete (newStatesLock[i]);
289 uint64_t bucket = hasher(cID) & bucketMask;
291 std::lock_guard<SpinLockCLSize> lock(*(newStatesLock[bucket]));
292 auto insPair = std::pair<ConnectionID, State>(cID, st);
293 newStates[bucket]->insert(insPair);
308 uint64_t bucket = hasher(cID) & bucketMask;
310 std::lock_guard<SpinLockCLSize> lock(*(newStatesLock[bucket]));
311 auto ret = newStates[bucket]->find(cID);
313 if (ret != newStates[bucket]->end()) {
314 st->
set(ret->second);
315 newStates[bucket]->erase(ret);
329 void erase(ConnectionID &cID);
342 std::unordered_map<ConnectionID, State, Hasher> stateTable;
346 std::vector<stateFun> functions;
350 Identifier identifier;
358 std::function<void *(ConnectionID)> startStateFun;
364 std::function<Packet *()> getPktCB;
367 bool listenToConnections;
378 std::chrono::time_point<std::chrono::steady_clock> time;
386 bool operator()(
struct Timeout a,
struct Timeout b) {
return a.time < b.time; };
391 uint32_t curTimeoutID;
394 std::priority_queue<struct Timeout, std::vector<struct Timeout>,
395 typename Timeout::Compare>
409 TimeoutData(ConnectionID
id,
timeoutFun fun) : id(id), fun(fun){};
417 std::unordered_map<uint32_t, std::unique_ptr<struct TimeoutData>> timeoutFunctions;
427 static ConnectionPool connPoolStatic;
428 ConnectionPool *connPool;
436 uint64_t stat_statesAdded = 0;
437 uint64_t stat_statesClosed = 0;
445 auto findState(ConnectionID
id) {
447 auto stateIt = stateTable.find(
id);
448 if (stateIt == stateTable.end()) {
453 if (connPool->findAndErase(
id, &st)) {
454 D(std::cout <<
"StateMachine::findState() found state in connPool" 456 stateTable.insert({id, st});
465 if (listenToConnections) {
467 D(std::cout <<
"Adding new state" << std::endl;)
468 D(std::cout <<
"ConnectionID: " << static_cast<std::string>(
id) << std::endl;)
471 void *stateData =
nullptr;
473 stateData = startStateFun(
id);
476 State s(startStateID, stateData);
477 stateTable.insert({id, s});
485 D(std::cout <<
"State found" << std::endl;)
491 D(std::cout << std::endl <<
"StateMachine::runPkt() called" << std::endl;)
495 Packet *pktIn = pktsIn[cur];
498 ConnectionID identity = identifier.identify(pktIn);
501 auto stateIt = findState(identity);
503 if (stateIt == stateTable.end()) {
505 D(std::cout <<
"StateMachine::runPkt() discarding packet" << std::endl;)
506 D(std::cout <<
"ident of packet: " << static_cast<std::string>(identity)
512 if (stateIt->second.timeoutID != timeoutIDInvalid) {
513 this->timeoutFunctions.erase(stateIt->second.timeoutID);
514 stateIt->second.timeoutID = timeoutIDInvalid;
526 assert((functions.size() - 1) >= stateIt->second.state);
527 auto fun = functions[stateIt->second.state];
528 assert(fun !=
nullptr);
531 FunIface funIface(
this, cur, pktsIn, identity, stateIt->second);
534 D(std::cout <<
"StateMachine::runPkt() Running Function" << std::endl;)
535 D(std::cout <<
"StateMachine::runPkt() identity: " 536 << static_cast<std::string>(identity) << std::endl;)
537 D(std::cout <<
"StateMachine::runPkt() hexdump of packet: " << std::endl;)
538 D(
hexdump(pktIn->getData(), pktIn->getDataLen());)
540 fun(stateIt->second, pktIn, funIface);
543 if (stateIt->second.state == endStateID) {
545 <<
"StateMachine::runPkt() Reached endStateID - deleting connection" 558 D(std::cout <<
"StateMachine::runPkt() Packet could not be identified" 572 : startStateID(0), endStateID(
StateIDInvalid), listenToConnections(false),
573 curTimeoutID(0), connPool(&connPoolStatic){};
576 std::cout <<
"StateMachine stats:" << std::endl;
577 std::cout <<
"stateTable.size() = " << stateTable.size() << std::endl;
578 std::cout <<
"statesAdded = " << stat_statesAdded << std::endl;
579 std::cout <<
"statesClosed = " << stat_statesClosed << std::endl;
599 if ((static_cast<int>(functions.size()) - 1) <
static_cast<int>(id)) {
600 functions.resize(
id + 1);
602 functions[id] =
function;
622 StateID startStateID, std::function<
void *(ConnectionID)> startStateFun) {
623 this->startStateID = startStateID;
624 this->startStateFun = startStateFun;
625 listenToConnections =
true;
661 assert((functions.size() - 1) >= st.state);
662 auto fun = functions[st.state];
663 assert(fun !=
nullptr);
665 FunIface funIface(
this, 0, pktsIn,
id, st);
667 D(std::cout <<
"StateMachine::addState() Running Function" << std::endl;)
669 fun(st, pktsIn[0], funIface);
671 if (st.state == endStateID) {
672 D(std::cout <<
"StateMachine::addState() Reached endStateID - deleting connection" 677 D(std::cout <<
"StateMachine::addState() adding connection to newStates" 679 D(std::cout <<
"StateMachine::addState() identity: " << static_cast<std::string>(
id)
681 connPool->add(
id, st);
695 D(std::cout << std::endl
696 <<
"StateMachine::runPktBatch() running incoming batch now, #Pkts: " 697 << inCount << std::endl;)
702 while (!timeoutsQ.empty()) {
704 auto timeoutElem = timeoutsQ.top();
707 if (timeoutElem.time > std::chrono::steady_clock::now()) {
712 uint32_t timeoutID = timeoutElem.timeoutID;
713 auto timeoutDataIt = timeoutFunctions.find(timeoutID);
716 if (timeoutDataIt == timeoutFunctions.end()) {
722 std::unique_ptr<struct TimeoutData> timeoutData =
723 std::move(timeoutDataIt->second);
724 auto stateIt = findState(timeoutData->id);
725 assert(stateIt != stateTable.end());
726 FunIface funIface(
this, std::numeric_limits<uint32_t>::max(), pktsIn,
727 timeoutData->id, stateIt->second);
730 stateIt->second.timeoutID = timeoutIDInvalid;
733 timeoutData->fun(stateIt->second, funIface);
736 if (stateIt->second.state == endStateID) {
737 D(std::cout <<
"Reached endStateID - deleting connection" << std::endl;)
743 timeoutFunctions.erase(timeoutDataIt);
747 for (uint32_t i = 0; i < inCount; i++) {
756 template <
class Identifier,
class Packet>
void runPktBatch(BufArray< Packet > &pktsIn)
Run a batch of packets.
void removeState(ConnectionID id)
Remove a connection.
Main interface for the needs of a state function.
Packet * getPkt()
Get an additional packet buffer.
void markDropPkt(uint32_t pktIdx)
Mark one packet as drop.
void transitionNow(StateID newState)
Immediately transition to another state.
std::function< void(State &, mbuf *, FunIface &)> stateFun
This is the signature any state function needs to expose.
State(StateID state, void *stateData)
void hexdump(const void *data, int dataLen)
Dump hex data.
uint32_t getTotalCount() const
Get the number of all packets in the BufArray.
void registerGetPktCB(std::function< Packet *()> fun)
Register a callback in order to get new buffer.
bool findAndErase(ConnectionID &cID, State *st)
Try to find a state for a given connection ID.
void freePkt()
Free the packet after the batch is processed instead of sending it.
bool operator()(struct Timeout a, struct Timeout b)
void addState(ConnectionID id, State st, BufArray< Packet > &pktsIn)
Open an outgoing connection.
void registerStartStateID(StateID startStateID, std::function< void *(ConnectionID)> startStateFun)
This method describes how to proceed with incoming connections.
void add(ConnectionID &cID, State &st)
Add connection and state to the connection pool.
static constexpr auto StateIDInvalid
Represents an invalid StateID.
void setConnectionPool(ConnectionPool *cp)
void registerFunction(StateID id, stateFun function)
Register a function for a given state.
std::function< void(State &, FunIface &)> timeoutFun
This is the signature any timeout function needs to expose.
size_t getStateTableSize()
Get the number of tracked connections. This is probably only useful for statistics.
Represents one connection.
void setTimeout(std::chrono::milliseconds timeout, timeoutFun fun)
Set a timeout, after which a transition will happen.
Wrapper around MoonGen bufarrays.
void registerEndStateID(StateID endStateID)
This registers an end state. If a connection reaches this ID, it will be destroyed.
void transition(StateID newState)
Transition to another state.