I'm trying to implement the Wu-Manber algorithm (http://webglimpse.net/pubs/TR94-17.pdf). From my understanding, the algorithm performs multi-pattern string matching using a hash-table-based approach, with chaining to resolve collisions.
I have not used C++ for quite a while. I'm looking for help with coding conventions (for better readability) and with performance tweaks, if possible.
The original algorithm does not work with short patterns (shorter than the block size B), so I added two lookup tables, `lengthOnePatternLookup_` and `lengthTwoPatternLookup_`, to handle patterns of length 1 and 2.
#ifndef WU_MANBER_HPP
#define WU_MANBER_HPP

#include <string>
#include <vector>
#include <stdexcept>
#include <algorithm>
#include <functional>
#include <math.h>
#include <cstddef>
#include <limits>
#include <map>
#include <type_traits>
#include <utility>

namespace wu_manber {

/// Reduces `input` into the range [0, ceil), applying the modulo operator only
/// when `input` is not already in range
/// (ref: https://www.youtube.com/watch?v=nXaxk27zwlk&feature=youtu.be&t=56m34s).
///
/// `inline` rather than an anonymous namespace: this is a header, so the
/// function must be safe to include from several translation units.
///
/// @param input value to reduce
/// @param ceil  exclusive upper bound; must be greater than 0
/// @return `input` if it is smaller than `ceil`, otherwise `input % ceil`
inline unsigned int fastmod(const unsigned int input, const std::size_t ceil) {
    return input >= ceil ? static_cast<unsigned int>(input % ceil) : input;
}

/// Multi-pattern string matcher based on the Wu-Manber SHIFT / HASH / PREFIX
/// tables, extended with separate lookups for patterns of length 1 and 2
/// (which are shorter than the block size and so cannot go through the tables).
///
/// Usage: construct, call preProcess() with the patterns, then call scan() for
/// every text to search.
///
/// Note on the index passed to the match callback: it is the position of the
/// pattern inside the list it was stored in. For patterns of length >= 3 that
/// is patternList(); patterns of length 1 and of length 2 are each numbered
/// separately, in the order they appeared in the input to preProcess().
///
/// @tparam CharType code unit type of the strings; must be an integral
///                  character type
template<typename CharType>
class WuManber {
public:
    using StringType = std::basic_string<CharType>;

    /// @param HBITS     number of bits the hash is shifted by per character;
    ///                  must be smaller than the bit width of unsigned int
    /// @param tableSize number of entries of the SHIFT and HASH tables; must
    ///                  be greater than 0
    /// @throws std::invalid_argument if either argument is out of range
    WuManber(unsigned short HBITS = 4, std::size_t tableSize = 32768)
        : m_(0),
          k_(0),
          HBITS_(HBITS),
          tableSize_(tableSize),
          isShortPatternExist_(false),
          isInitialized_(false) {
        if (tableSize_ == 0) {
            // a zero-sized table would make every hash reduction a modulo by zero
            throw std::invalid_argument("WuManber: tableSize must be greater than 0");
        }
        if (HBITS_ >= std::numeric_limits<unsigned int>::digits) {
            // shifting by the full width of the type (or more) is undefined
            throw std::invalid_argument("WuManber: HBITS must be smaller than the bit width of unsigned int");
        }
        shiftTable_.assign(tableSize_, 1);
        hashPrefixTable_.resize(tableSize_);
    }

    WuManber(const WuManber&) = delete;
    WuManber& operator =(const WuManber&) = delete;

    /// Patterns of length >= 3 accepted by the last preProcess() call, in
    /// input order.
    const std::vector<StringType>& patternList() const {
        return patternList_;
    }

    /// Builds all lookup tables for `patterns`, replacing whatever a previous
    /// call installed. Empty patterns are ignored.
    ///
    /// @param patterns the patterns to search for
    void preProcess(const std::vector<StringType> &patterns) {
        // Sort the input into local lists first, so that passing patternList()
        // back in is safe.
        std::vector<StringType> longPatterns;
        std::vector<StringType> onePatterns;
        std::vector<StringType> twoPatterns;
        std::size_t minLength = 0;
        for (const auto &pattern : patterns) {
            const std::size_t length = pattern.size();
            if (length >= B_) {
                minLength = longPatterns.empty() ? length : std::min(length, minLength);
                longPatterns.push_back(pattern);
            } else if (length == 1) {
                onePatterns.push_back(pattern);
            } else if (length == 2) {
                twoPatterns.push_back(pattern);
            }
        }

        isInitialized_ = false;
        patternList_ = std::move(longPatterns);
        lengthOnePatterns_ = std::move(onePatterns);
        lengthTwoPatterns_ = std::move(twoPatterns);
        m_ = minLength;
        k_ = patternList_.size();

        buildShiftAndHashTables_();
        buildShortPatternLookups_();

        isInitialized_ = true;
    }

    /// Reports every occurrence of every pattern in `text`.
    ///
    /// Does nothing if preProcess() has not been called or `text` is empty.
    ///
    /// @param text    the text to search
    /// @param onMatch called once per match with 3 arguments: the matched
    ///                pattern, the pattern's index (see the class comment) and
    ///                the start index of the match in `text`
    void scan(const StringType &text, std::function<void(const StringType&, std::size_t, std::size_t)> onMatch) const {
        const std::size_t textLength = text.size();
        if (!isInitialized_ || textLength == 0) {
            return;
        }

        const bool hasLongPatterns = !patternList_.empty();

        if (isShortPatternExist_) {
            // Positions before m_ - 1 are never visited by the main loop, so
            // short patterns have to be checked there separately. Without any
            // long pattern this loop covers the whole text.
            const std::size_t shortOnlyEnd = hasLongPatterns ? std::min(m_ - 1, textLength) : textLength;
            for (std::size_t idx = 0; idx < shortOnlyEnd; ++idx) {
                checkShortPattern_(text, idx, onMatch);
            }
        }

        if (!hasLongPatterns) {
            return;
        }

        // idx is the position of the LAST character of the current window
        std::size_t idx = m_ - 1;
        while (idx < textLength) {
            if (isShortPatternExist_) {
                checkShortPattern_(text, idx, onMatch);
            }

            const std::size_t slot = blockHash_(text, idx);
            std::size_t shiftLength = shiftTable_[slot];
            if (shiftLength == 0) {
                // potential match: check the candidates in the HASH/PREFIX
                // table, then move on by one character
                shiftLength = 1;

                const std::size_t start = idx - m_ + 1;
                const unsigned int prefixHash = prefixHash_(text, start);
                for (const auto &candidate : hashPrefixTable_[slot]) {
                    if (candidate.prefixHash != prefixHash) {
                        continue;
                    }
                    // The hashes can collide and cover only the first m_
                    // characters, so the whole pattern is compared here.
                    const StringType &pattern = patternList_[candidate.idx];
                    if (pattern.size() <= textLength - start
                        && text.compare(start, pattern.size(), pattern) == 0) {
                        onMatch(pattern, candidate.idx, start);
                    }
                }
            }

            // short patterns can start anywhere, so no position may be skipped
            idx += isShortPatternExist_ ? 1 : shiftLength;
        }
    }

private:
    using MatchCallback = std::function<void(const StringType&, std::size_t, std::size_t)>;

    // Code units are converted to this type before being hashed or used as an
    // index, so that negative values of a signed CharType cannot produce huge
    // indices.
    using UCharType = typename std::make_unsigned<CharType>::type;

    // store index in pattern list + prefix hash value for each pattern
    struct PatternHash {
        unsigned int prefixHash;
        std::size_t idx;
    };

    // block size
    // the paper says in practice, we use either B = 2 or B = 3
    // we'll use 3 (blockHash_ hashes exactly three characters)
    static constexpr std::size_t B_ = 3;

    // Dense arrays are only affordable for one-byte code units; wider types
    // use ordered maps for the short-pattern lookups instead.
    static constexpr bool kDenseShortLookup = sizeof(CharType) == 1;
    static constexpr std::size_t kDenseAlphabet =
        static_cast<std::size_t>(std::numeric_limits<unsigned char>::max()) + 1;

    // min pattern size among the patterns in patternList_ (0 if there are none)
    std::size_t m_;

    // number of patterns to be processed by Wu - Manber
    std::size_t k_;

    // number of bits to shift when hashing
    // the paper says it use 5
    unsigned short HBITS_;

    // size of HASH and SHIFT tables
    std::size_t tableSize_;

    // SHIFT table
    std::vector<std::size_t> shiftTable_;

    // HASH + PREFIX table
    std::vector<std::vector<PatternHash>> hashPrefixTable_;

    // pattern list (patterns of length >= B_)
    std::vector<StringType> patternList_;

    // handle length 1 and 2 patterns
    bool isShortPatternExist_;
    std::vector<StringType> lengthOnePatterns_;
    std::vector<StringType> lengthTwoPatterns_;
    // dense lookups (one-byte CharType): -1 means "no pattern"
    std::vector<int> lengthOnePatternLookup_;
    std::vector<int> lengthTwoPatternLookup_;  // flattened [first][second]
    // sparse lookups (wider CharType)
    std::map<CharType, int> sparseLengthOneLookup_;
    std::map<std::pair<CharType, CharType>, int> sparseLengthTwoLookup_;

    bool isInitialized_;

    static unsigned int toUnsigned_(CharType c) {
        return static_cast<unsigned int>(static_cast<UCharType>(c));
    }

    static std::size_t denseIndex_(CharType c) {
        return static_cast<std::size_t>(static_cast<UCharType>(c));
    }

    // Table slot for the block of B_ characters of `s` that ends at index `last`.
    std::size_t blockHash_(const StringType &s, std::size_t last) const {
        unsigned int hash = toUnsigned_(s[last]);
        hash <<= HBITS_;
        hash += toUnsigned_(s[last - 1]);
        hash <<= HBITS_;
        hash += toUnsigned_(s[last - 2]);
        return fastmod(hash, tableSize_);
    }

    // Hash of the two characters of `s` starting at index `first`; used to
    // skip candidates that share a HASH slot but start differently.
    unsigned int prefixHash_(const StringType &s, std::size_t first) const {
        unsigned int hash = toUnsigned_(s[first]);
        hash <<= HBITS_;
        hash += toUnsigned_(s[first + 1]);
        return hash;
    }

    // Fills the SHIFT and HASH/PREFIX tables from patternList_ and m_.
    void buildShiftAndHashTables_() {
        // Default shift; the value is irrelevant when there are no long
        // patterns, because scan() does not consult the table then.
        const std::size_t defaultShift = patternList_.empty() ? 1 : m_ - B_ + 1;
        shiftTable_.assign(tableSize_, defaultShift);
        for (auto &bucket : hashPrefixTable_) {
            bucket.clear();
        }

        for (std::size_t i = 0; i < k_; ++i) {
            const StringType &pattern = patternList_[i];
            // only the first m_ characters of every pattern take part
            for (std::size_t end = B_; end <= m_; ++end) {
                const std::size_t slot = blockHash_(pattern, end - 1);
                const std::size_t shiftLength = m_ - end;
                shiftTable_[slot] = std::min(shiftTable_[slot], shiftLength);
                if (shiftLength == 0) {
                    PatternHash entry;
                    entry.idx = i;
                    entry.prefixHash = prefixHash_(pattern, 0);
                    hashPrefixTable_[slot].push_back(entry);
                }
            }
        }
    }

    // Fills the length-1 / length-2 lookups. If the same short pattern occurs
    // more than once, the last occurrence wins.
    void buildShortPatternLookups_() {
        lengthOnePatternLookup_.clear();
        lengthTwoPatternLookup_.clear();
        sparseLengthOneLookup_.clear();
        sparseLengthTwoLookup_.clear();

        isShortPatternExist_ = !lengthOnePatterns_.empty() || !lengthTwoPatterns_.empty();
        if (!isShortPatternExist_) {
            return;
        }

        if (kDenseShortLookup) {
            lengthOnePatternLookup_.assign(kDenseAlphabet, -1);
            lengthTwoPatternLookup_.assign(kDenseAlphabet * kDenseAlphabet, -1);
            for (std::size_t i = 0; i < lengthOnePatterns_.size(); ++i) {
                lengthOnePatternLookup_[denseIndex_(lengthOnePatterns_[i][0])] = static_cast<int>(i);
            }
            for (std::size_t i = 0; i < lengthTwoPatterns_.size(); ++i) {
                const StringType &pattern = lengthTwoPatterns_[i];
                lengthTwoPatternLookup_[denseIndex_(pattern[0]) * kDenseAlphabet + denseIndex_(pattern[1])] =
                    static_cast<int>(i);
            }
        } else {
            for (std::size_t i = 0; i < lengthOnePatterns_.size(); ++i) {
                sparseLengthOneLookup_[lengthOnePatterns_[i][0]] = static_cast<int>(i);
            }
            for (std::size_t i = 0; i < lengthTwoPatterns_.size(); ++i) {
                const StringType &pattern = lengthTwoPatterns_[i];
                sparseLengthTwoLookup_[std::make_pair(pattern[0], pattern[1])] = static_cast<int>(i);
            }
        }
    }

    // Index of the length-1 pattern equal to `c`, or -1.
    int lengthOneIndex_(CharType c) const {
        if (kDenseShortLookup) {
            return lengthOnePatternLookup_[denseIndex_(c)];
        }
        const auto it = sparseLengthOneLookup_.find(c);
        return it == sparseLengthOneLookup_.end() ? -1 : it->second;
    }

    // Index of the length-2 pattern equal to `first` followed by `second`, or -1.
    int lengthTwoIndex_(CharType first, CharType second) const {
        if (kDenseShortLookup) {
            return lengthTwoPatternLookup_[denseIndex_(first) * kDenseAlphabet + denseIndex_(second)];
        }
        const auto it = sparseLengthTwoLookup_.find(std::make_pair(first, second));
        return it == sparseLengthTwoLookup_.end() ? -1 : it->second;
    }

    // Reports the length-1 pattern at cur_idx and the length-2 pattern ending
    // at cur_idx, if any. The callback is taken by reference: this runs once
    // per text position and copying a std::function each time is wasteful.
    void checkShortPattern_(const StringType &text, std::size_t cur_idx, const MatchCallback &onMatch) const {
        const int l1MatchIndex = lengthOneIndex_(text[cur_idx]);
        if (l1MatchIndex > -1) {
            onMatch(lengthOnePatterns_[l1MatchIndex], l1MatchIndex, cur_idx);
        }

        if (cur_idx == 0) {
            return;  // no previous character, so no length-2 pattern can end here
        }
        const int l2MatchIndex = lengthTwoIndex_(text[cur_idx - 1], text[cur_idx]);
        if (l2MatchIndex > -1) {
            onMatch(lengthTwoPatterns_[l2MatchIndex], l2MatchIndex, cur_idx - 1);
        }
    }
};

} // namespace wu_manber

#endif // WU_MANBER_HPP
If the code is too long to read comfortably here, it is also available on GitHub: https://github.com/bubiche/wu_manber/blob/master/wu_manber.hpp
Thank you for your help!