Skip to content

Commit

Permalink
cleaner inheritence
Browse files Browse the repository at this point in the history
  • Loading branch information
emjotde committed Jan 24, 2018
1 parent a1e22b1 commit 6e223c2
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion src/data/corpus.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ Corpus::Corpus(std::vector<std::string> paths,
std::vector<Ptr<Vocab>> vocabs,
Ptr<Config> options,
size_t maxLength)
: DatasetBase(paths),
: CorpusBase(paths),
options_(options),
vocabs_(vocabs),
maxLength_(maxLength ? maxLength : options_->get<size_t>("max-length")),
Expand Down
9 changes: 8 additions & 1 deletion src/data/corpus.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,14 @@ class CorpusBatch : public Batch {

class CorpusIterator;

typedef DatasetBase<SentenceTuple, CorpusIterator, CorpusBatch> CorpusBase;
class CorpusBase : public DatasetBase<SentenceTuple, CorpusIterator, CorpusBatch> {
public:

CorpusBase() : DatasetBase() {}
CorpusBase(std::vector<std::string> paths) : DatasetBase(paths) {}

virtual std::vector<Ptr<Vocab>>& getVocabs() = 0;
};

class CorpusIterator
: public boost::iterator_facade<CorpusIterator,
Expand Down
2 changes: 1 addition & 1 deletion src/data/corpus_sqlite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ CorpusSQLite::CorpusSQLite(std::vector<std::string> paths,
std::vector<Ptr<Vocab>> vocabs,
Ptr<Config> options,
size_t maxLength)
: DatasetBase(paths),
: CorpusBase(paths),
options_(options),
vocabs_(vocabs),
maxLength_(maxLength ? maxLength : options_->get<size_t>("max-length")),
Expand Down
9 changes: 1 addition & 8 deletions src/training/training.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,7 @@ class Train : public ModelTask {
if((options_->has("valid-sets") || options_->has("valid-script-path"))
&& options_->get<size_t>("valid-freq") > 0) {

// @TODO: solve this with better polymorphism
std::vector<Ptr<Vocab>> vocabs;
if(options_->get<bool>("sqlite"))
vocabs = std::static_pointer_cast<CorpusSQLite>(dataset)->getVocabs();
else
vocabs = std::static_pointer_cast<Corpus>(dataset)->getVocabs();

for(auto validator : Validators(vocabs, options_))
for(auto validator : Validators(dataset->getVocabs(), options_))
scheduler->addValidator(validator);
}

Expand Down

0 comments on commit 6e223c2

Please sign in to comment.