Month: May 2013

Trie 헤더 파일

연구참여를 하면서 Visual C++로 간단히 작성했던 Trie. 묵혀두기 아까워서 공유합니다 -_-aㅋㅋㅋ

string을 key로 하는 dictionary type을 만들고 싶을 때, STL의 map보다 월등히 높은 성능(특히 메모리)을 나타냅니다.

iterator, get/put/erase 모두 지원합니다 :)

Memory leak이 없는 것은 확인하였고, class를 value로 하는 타입에 대해서는 프로그램이 제대로 동작하지 않을 수도 있습니다 -,.-

Key로 한국어도 가능합니다.

 

Trie.h :

/*
author : akdal
https://agidari.wordpress.com
*/
#ifndef TRIE_H

#define TRIE_H

#include<assert.h>
#include<algorithm>
#include<cstring>
#include<iostream>
#include<string>
#include<utility>
#include<vector>
using namespace std;

/**
 * Prefix tree mapping NUL-terminated byte strings to values of type T.
 *
 * Keys are consumed one `char` at a time, so any byte sequence without an
 * embedded NUL (including multi-byte encoded text) can be used as a key.
 * T must be default-constructible and copyable.
 *
 * The trie owns all of its nodes; copying a trie performs a deep copy.
 * Pointers returned by get() and iterators obtained from itr() refer to
 * nodes owned by the trie and must not be used after the trie is modified
 * or destroyed.
 */
template<class T>
class Trie{
public:
        //static int NODE_ALLOCCNT;
        //static int NODE_DEALLOCCNT;
protected:
        /**
         * One trie node: an optional value plus a list of (edge character,
         * child) pairs. A node owns its children and deletes them on
         * destruction.
         */
        class Node{
        public:
                typedef std::pair<char, Node *> Edge;
                typedef std::vector<Edge> EdgeList;

                T value;               // payload; meaningful only while hasValue is true
                bool hasValue;         // true when a key ends at this node
                EdgeList *children;    // owned edge list; nullptr while the node has no children
                bool isLeaf;           // true when the node has no children
                bool isConstant;       // true once makeConstant() sorted the edges

                /** Creates a childless node that already stores `value`. */
                Node(const T &value)
                        : value(value), hasValue(true), children(nullptr), isLeaf(true), isConstant(false)
                {
                        //Trie<T>::NODE_ALLOCCNT++;
                }
                /** Creates an empty, childless node. The value is value-initialized. */
                Node()
                        : value(), hasValue(false), children(nullptr), isLeaf(true), isConstant(false)
                {
                        //Trie<T>::NODE_ALLOCCNT++;
                }
                /** Deep-copies `othernode` together with its whole subtree. */
                Node(const Node &othernode)
                        : value(othernode.value), hasValue(othernode.hasValue), children(nullptr),
                          isLeaf(othernode.isLeaf), isConstant(othernode.isConstant)
                {
                        //Trie<T>::NODE_ALLOCCNT++;
                        if(othernode.children == nullptr)
                                return;
                        try{
                                children = new EdgeList();
                                // Reserving up front means push_back below cannot reallocate
                                // (and therefore cannot throw after a child was allocated).
                                children->reserve(othernode.children->size());
                                for(typename EdgeList::const_iterator itr = othernode.children->begin();
                                    itr != othernode.children->end(); ++itr)
                                        children->push_back(Edge(itr->first, new Node(*itr->second)));
                        }catch(...){
                                // The destructor does not run for a half-built object.
                                destroyChildren();
                                throw;
                        }
                }
                /** Deep-copy assignment (copy-and-swap, so self-assignment is safe). */
                Node &operator=(const Node &othernode){
                        if(this != &othernode){
                                Node copy(othernode);
                                using std::swap;
                                swap(value, copy.value);
                                swap(hasValue, copy.hasValue);
                                swap(children, copy.children);
                                swap(isLeaf, copy.isLeaf);
                                swap(isConstant, copy.isConstant);
                        }
                        return *this;
                }
                /** Destroys the whole subtree below this node. */
                ~Node(){
                        destroyChildren();
                        //Trie<T>::NODE_DEALLOCCNT++;
                }
                /**
                 * Marks this subtree as read-only and sorts every edge list by
                 * edge character so that getNode() can use binary search.
                 * After this call addNode() and deleteNode() must not be used.
                 */
                void makeConstant(){
                        isConstant = true;
                        if(!hasEdges())
                                return;
                        std::sort(children->begin(), children->end(), edgeLess);
                        for(typename EdgeList::iterator itr = children->begin(); itr != children->end(); ++itr)
                                itr->second->makeConstant();
                }
                /**
                 * Returns the child reached through edge `c`, or nullptr when
                 * there is no such edge.
                 */
                Node *getNode(char c){
                        if(!hasEdges())
                                return nullptr;

                        if(isConstant){
                                // Edges were sorted by makeConstant(): binary search.
                                typename EdgeList::iterator pos =
                                        std::lower_bound(children->begin(), children->end(), c, edgeBelow);
                                if(pos != children->end() && pos->first == c)
                                        return pos->second;
                                return nullptr;
                        }
                        for(typename EdgeList::iterator itr = children->begin(); itr != children->end(); ++itr){
                                if(itr->first == c)
                                        return itr->second;
                        }
                        return nullptr;
                }
                /**
                 * Removes and destroys the child reached through edge `c`.
                 * Does nothing when no such edge exists. Must not be called
                 * on a leaf or on a constant node (both are asserted).
                 */
                void deleteNode(char c){
                        assert(!isConstant);
                        if(!hasEdges()){
                                assert(false);
                                return;
                        }
                        for(typename EdgeList::iterator itr = children->begin(); itr != children->end(); ++itr){
                                if(itr->first == c){
                                        Node *victim = itr->second;
                                        children->erase(itr);
                                        assert(victim != nullptr);
                                        delete victim;
                                        break;
                                }
                        }
                        if(children->empty()){
                                // Release the empty list so a later addNode() does not leak it.
                                delete children;
                                children = nullptr;
                                isLeaf = true;
                        }
                }
                /** Returns the i-th child (bounds-checked; throws std::out_of_range). */
                Node *childAt(int i)
                { return this->children->at(i).second; }
                /** Returns the edge character of the i-th child (bounds-checked). */
                char edgeAt(int i)
                { return this->children->at(i).first; }
                /** Returns the number of children (0 for a leaf). */
                int size()
                { return hasEdges() ? static_cast<int>(children->size()) : 0; }
                /**
                 * Appends a new empty child reached through edge `c` and returns it.
                 * The caller is responsible for not adding a duplicate edge.
                 */
                Node *addNode(char c){
                        assert(!isConstant);
                        Node *n = new Node();
                        try{
                                if(children == nullptr)
                                        children = new EdgeList();
                                children->push_back(Edge(c, n));
                        }catch(...){
                                delete n;
                                throw;
                        }
                        isLeaf = false;
                        return n;
                }
                /**
                 * Writes a human-readable dump of this subtree to `ostr`,
                 * indenting every line by `space` blanks and each level by one more.
                 */
                void print(std::ostream &ostr, int space){
                        if(hasValue){
                                indent(ostr, space);
                                ostr << "VALUE='" << value << "'" << std::endl;
                        }
                        if(!hasEdges())
                                return;
                        for(typename EdgeList::iterator itr = children->begin(); itr != children->end(); ++itr){
                                indent(ostr, space);
                                ostr << "edge '" << itr->first << "'" << std::endl;
                                itr->second->print(ostr, space + 1);
                        }
                }

        private:
                /** True when the node really has an edge list to walk. */
                bool hasEdges() const
                { return !isLeaf && children != nullptr; }
                /** Orders edges by their character (used for sorting). */
                static bool edgeLess(const Edge &lhs, const Edge &rhs)
                { return lhs.first < rhs.first; }
                /** Orders an edge against a bare character (used for binary search). */
                static bool edgeBelow(const Edge &edge, char c)
                { return edge.first < c; }
                /** Writes `space` blanks to `ostr`. */
                static void indent(std::ostream &ostr, int space){
                        for(int i = 0; i < space; i++)
                                ostr << (" ");
                }
                /** Deletes every child and the edge list itself. */
                void destroyChildren(){
                        if(children == nullptr)
                                return;
                        for(typename EdgeList::iterator itr = children->begin(); itr != children->end(); ++itr)
                                delete itr->second;
                        delete children;
                        children = nullptr;
                }
        };

public:
        /**
         * Depth-first (pre-order) cursor over all stored values.
         *
         * Usage: call next() repeatedly until it returns nullptr; after each
         * successful call getCurrentTrace() yields the key of the returned value.
         * The iterator does not own the nodes it visits and is invalidated by
         * any modification of the trie. Copy, move and destruction are handled
         * by the standard containers used for the internal buffers.
         */
        class iterator{
        private:
                std::vector<char> trace;          // key bytes of the current path, NUL-terminated
                std::vector<int> stack;           // stack[d]: index of the child taken at depth d
                std::vector<Node *> nodeStack;    // nodes on the current path, root first
                //int stackSize;
                int nodeStackSize;                // number of valid entries in nodeStack
                int maxStackSize;                 // depth hint given at construction
                bool end;                         // true once the traversal is exhausted
                Node *root;

                /** Number of buffer slots needed for a depth hint (never zero). */
                static std::size_t slotCount(int depthHint)
                { return static_cast<std::size_t>(depthHint > 0 ? depthHint : 0) + 1; }

                /** Grows the buffers so that index `level` is valid in all of them. */
                void _ensureLevel(int level){
                        const std::size_t needed = static_cast<std::size_t>(level) + 1;
                        if(needed <= nodeStack.size())
                                return;
                        nodeStack.resize(needed, static_cast<Node *>(nullptr));
                        trace.resize(needed, '\0');
                        stack.resize(needed, 0);
                }

                /** Pushes the idx-th child of the current top node onto the path. */
                void _descend(int idx){
                        Node *parent = nodeStack[nodeStackSize - 1];
                        _ensureLevel(nodeStackSize);
                        nodeStack[nodeStackSize] = parent->childAt(idx);
                        trace[nodeStackSize - 1] = parent->edgeAt(idx);
                        trace[nodeStackSize] = 0;
                        stack[nodeStackSize - 1] = idx;
                        nodeStackSize++;
                }

        protected:

                /** Advances the cursor by exactly one node in pre-order. */
                void _nextItr(){
                        if(end)
                                return;
                        if(nodeStackSize == 0){
                                // First step: start at the root with an empty key.
                                _ensureLevel(0);
                                nodeStack[0] = root;
                                trace[0] = 0;
                                nodeStackSize = 1;
                                return;
                        }
                        if(nodeStack[nodeStackSize - 1]->size() > 0){
                                // The current node has children: go down to the first one.
                                _descend(0);
                                return;
                        }
                        // The current node is a leaf: pop until some ancestor
                        // still has an unvisited child, then enter that child.
                        while(true){
                                nodeStackSize--;
                                if(nodeStackSize == 0){
                                        end = true;
                                        break;
                                }
                                const int idx = stack[nodeStackSize - 1] + 1;
                                if(idx < nodeStack[nodeStackSize - 1]->size()){
                                        _descend(idx);
                                        break;
                                }
                        }
                }
        public:
                /**
                 * @param maxStackSize expected maximum path length (a hint; the
                 *                     buffers grow if the trie turns out deeper)
                 * @param root         node to start from; nullptr yields an
                 *                     already-exhausted iterator
                 */
                iterator(int maxStackSize, Node *root)
                        : trace(slotCount(maxStackSize), '\0'),
                          stack(slotCount(maxStackSize), 0),
                          nodeStack(slotCount(maxStackSize), static_cast<Node *>(nullptr)),
                          nodeStackSize(0),
                          maxStackSize(maxStackSize),
                          end(root == nullptr),
                          root(root)
                {}

                /** Returns true once next() has run past the last value. */
                bool isEnd()
                { return this->end; }
                /**
                 * Moves to the next stored value.
                 * @return pointer to that value, or nullptr when there are no more.
                 */
                T *next(){
                        if(end)
                                return nullptr;

                        do{
                                _nextItr();
                        }while(!end && !nodeStack[nodeStackSize - 1]->hasValue);

                        if(end)
                                return nullptr;
                        return &nodeStack[nodeStackSize - 1]->value;
                }
                /**
                 * Returns the NUL-terminated key of the value most recently
                 * returned by next(). The pointer is only guaranteed to stay
                 * valid until the next call to next().
                 */
                const char *getCurrentTrace()
                { return this->trace.data(); }
        };

private:
        Node *rootNode;     // owned; never nullptr
        int maxDepth;       // upper bound of (longest key length + 1), used to size iterators
        int elemSize;       // number of stored keys
        bool doConstant;    // true after makeConstant(); mutating calls then assert

        /** Returns true when a value is stored under exactly `letters`. */
        bool _hasKey(const char *letters){
                Node *cursor = rootNode;
                for(; *letters != 0; ++letters){
                        cursor = cursor->getNode(*letters);
                        if(cursor == nullptr)
                                return false;
                }
                return cursor->hasValue;
        }
public:
        /** Creates an empty trie. */
        Trie()
                : rootNode(new Node()), maxDepth(1), elemSize(0), doConstant(false)
        {}
        /** Deep-copies `orgtrie`. The copy is always mutable (not constant). */
        Trie(const Trie &orgtrie)
                : rootNode(new Node(*(orgtrie.rootNode))),
                  maxDepth(orgtrie.maxDepth),
                  elemSize(orgtrie.elemSize),
                  doConstant(false)
        {}
        /** Deep-copy assignment; like the copy constructor, the result is mutable. */
        Trie &operator=(const Trie &orgtrie){
                if(this != &orgtrie){
                        // Build the copy first so *this is untouched if copying throws.
                        Node *copy = new Node(*(orgtrie.rootNode));
                        delete this->rootNode;
                        this->rootNode = copy;
                        this->maxDepth = orgtrie.maxDepth;
                        this->elemSize = orgtrie.elemSize;
                        this->doConstant = false;
                }
                return *this;
        }
        /** Releases every node of the trie. */
        ~Trie()
        { delete this->rootNode; }

        /**
         * Freezes the trie: subsequent put()/erase()/clear() calls assert.
         * NOTE(review): this only sets the trie-level flag and does not call
         * Node::makeConstant(), so lookups keep using the linear edge scan —
         * confirm whether sorting the nodes here was intended.
         */
        void makeConstant()
        { this->doConstant = true; }
        void put(const char *letters, const T& value);
        void put(const std::string &str, const T& value);
        /**
         * Removes the value stored under `str` and prunes nodes that became
         * useless. Erasing a key that is not present is a no-op.
         */
        void erase(const std::string &str){
                assert(!doConstant);
                if(!this->_hasKey(str.c_str()))
                        return;
                this->_erase(str.c_str(), this->rootNode);
                this->elemSize--;
        }
        T *get(const char *letters);
        T *get(const std::string &str);
        typename Trie<T>::iterator itr();
        void clear();
        /** Returns the number of stored keys. */
        int size() const
        { return this->elemSize; }

        /** Writes a human-readable dump of the whole trie to `st`. */
        void print(std::ostream &st)
        {this->rootNode->print(st, 0); }

protected:
        void _put(const char *letters, Node *focus, const T& t);
        /**
         * Clears the value stored under `chrs` below `node` and deletes child
         * nodes that end up without value and without children.
         * @return true when `node` itself is now a valueless leaf, i.e. the
         *         caller may remove it from its parent; false otherwise
         *         (including when the key does not exist below `node`).
         */
        bool _erase(const char *chrs, Node *node){
                if(*chrs == 0){
                        node->hasValue = false;
                        return node->isLeaf;
                }
                Node *n = node->getNode(*chrs);
                if(n == nullptr)
                        return false;
                if(this->_erase(chrs + 1, n))
                        node->deleteNode(*chrs);
                return node->isLeaf && !node->hasValue;
        }
        T *_get(const char *letters, Node *focus);
        T *_get(const std::string &str, int idx, int len, Node *focus);
};

#include"Trie.h"
#include<assert.h>

/**
 * Stores `t` under the key `letters`, starting the walk at `focus` and
 * creating every node that is missing along the path.
 */
template<class T>
void Trie<T>::_put(const char *letters, Node *focus, const T& t){
        Node *cursor = focus;
        for(const char *ch = letters; *ch != 0; ++ch){
                Node *child = cursor->getNode(*ch);
                // Follow the existing edge, or grow a new one for this character.
                cursor = (child != NULL) ? child : cursor->addNode(*ch);
        }
        // cursor now sits on the node that spells the complete key.
        cursor->hasValue = true;
        cursor->value = t;
}
/**
 * Inserts `t` under the NUL-terminated key `letters`, or overwrites the
 * value already stored there. size() grows only when the key is new.
 * Must not be called after makeConstant() (asserted).
 */
template<class T>
void Trie<T>::put(const char *letters, const T &t)
{
        assert(!this->doConstant);
        // A key of k characters occupies k + 1 nodes on the path from the root.
        const int len = static_cast<int>(strlen(letters)) + 1;
        if(len > this->maxDepth)
                this->maxDepth = len;
        // Overwriting an existing key must not be counted as a new element.
        const bool isNewKey = (this->get(letters) == NULL);
        this->_put(letters, this->rootNode, t);
        if(isNewKey)
                this->elemSize++;
}
/**
 * std::string convenience overload: forwards the string's C representation
 * to put(const char *, const T &).
 */
template<class T>
void Trie<T>::put(const string &str, const T& value){
        assert(!this->doConstant);
        const char *letters = str.c_str();
        this->put(letters, value);
}

/**
 * Looks up the NUL-terminated key `letters` below `node`.
 * @return pointer to the stored value, or NULL when the path does not exist
 *         or no value is stored at its end.
 */
template<class T>
T* Trie<T>::_get(const char *letters, Node *node){
        Node *cursor = node;
        for(const char *ch = letters; *ch != 0; ++ch){
                cursor = cursor->getNode(*ch);
                if(cursor == NULL)
                        return NULL;
        }
        if(!cursor->hasValue)
                return NULL;
        return &(cursor->value);
}
/**
 * Looks up the characters str[idx] .. str[len - 1] below `node`.
 * Character access is bounds-checked through std::string::at().
 * @return pointer to the stored value, or NULL when the path does not exist
 *         or no value is stored at its end.
 */
template<class T>
T* Trie<T>::_get(const string &str, int idx, int len, Node *node){
        Node *cursor = node;
        for(int pos = idx; pos != len; ++pos){
                cursor = cursor->getNode(str.at(pos));
                if(cursor == NULL)
                        return NULL;
        }
        if(!cursor->hasValue)
                return NULL;
        return &(cursor->value);
}

/**
 * Returns a pointer to the value stored under the NUL-terminated key
 * `letters`, or NULL when the key is not present.
 */
template<class T>
T *Trie<T>::get(const char *letters)
{
        Node *const start = this->rootNode;
        return this->_get(letters, start);
}
/**
 * Returns a pointer to the value stored under `letters`, or NULL when the
 * key is not present. The whole string (its full length) is used as the key.
 */
template<class T>
T *Trie<T>::get(const string &letters)
{
        const int len = letters.length();
        return this->_get(letters, 0, len, this->rootNode);
}
/*
template<class T>
void Trie<T>::erase(const string &str){
        this->_erase(s.c_str(), this->rootNode);
}
*/
/**
 * Creates an iterator positioned before the first stored value.
 * The iterator is sized from the current maximum key depth and refers to
 * nodes owned by this trie, so it must not outlive the trie.
 */
template<class T>
typename Trie<T>::iterator Trie<T>::itr(){
        return Trie<T>::iterator(this->maxDepth, this->rootNode);
}

/**
 * Removes every key and value, returning the trie to its freshly
 * constructed state. Must not be called after makeConstant() (asserted).
 */
template<class T>
void Trie<T>::clear(){
        assert(!this->doConstant);
        // Allocate the replacement first: if this throws, the trie is unchanged
        // instead of being left with a dangling root pointer.
        Node *fresh = new Node();
        delete this->rootNode;
        this->rootNode = fresh;
        // Same initial depth as the default constructor uses.
        this->maxDepth = 1;
        this->elemSize = 0;
}

#endif

Trie_test.cpp :

/*
 * Author : akdal
 * https://agidari.wordpress.com
 * */
#include"Trie.h"
#include<iostream>
using namespace std;

int main(){
	Trie<int> t;
	Trie<bool> tt;
	t.put("aaaaa", 5);
	t.put("a", 1);
	t.put("ab", 2);
	t.put("bc", 2);
	t.put("b", 1);
	t.put("테스트", 4); // Korean is also possible.
	t.put("테스블", 5);

	int *p_aaa = t.get("aaa");
	int *p_a = t.get("a");
	int *p_dd = t.get("dd");
	int *p_bc = t.get("bc");
	int *p_ab = t.get("ab");
	int *p_b = t.get("b");

	t.print(cout);
	Trie<int>::iterator itr = t.itr();
	while(!itr.isEnd()){
		int *val = itr.next();
		if(val == NULL)
			break;
		cout << itr.getCurrentTrace() << " : " << *val << endl;
	}
	t.erase("bc");
	t.print(cout);
	cout<<"--"<<endl;
	t.erase("a");
	t.erase("aaaaa");
	t.print(cout);
	cout<<"--"<<endl;
	return 0;
}