From 071b8fb6398ee18b73042c6fea72ded55080be4e Mon Sep 17 00:00:00 2001
From: marco-nicola
Date: Fri, 11 Dec 2020 11:14:59 +0100
Subject: [PATCH] Replace models.TokenOffsets with strutils.ByteOffsets

---
 models/bpemodel/bpemodel.go                  |  3 ++-
 models/bpemodel/bpemodel_test.go             | 21 +++++++++++----------
 models/models.go                             |  9 +++------
 models/wordpiecemodel/wordpiecemodel.go      |  7 ++++---
 models/wordpiecemodel/wordpiecemodel_test.go | 19 ++++++++++---------
 5 files changed, 30 insertions(+), 29 deletions(-)

diff --git a/models/bpemodel/bpemodel.go b/models/bpemodel/bpemodel.go
index de2f0b1..308f9fd 100644
--- a/models/bpemodel/bpemodel.go
+++ b/models/bpemodel/bpemodel.go
@@ -7,6 +7,7 @@ package bpemodel
 import (
 	"fmt"
 	"github.com/nlpodyssey/gotokenizers/models"
+	"github.com/nlpodyssey/gotokenizers/strutils"
 	"github.com/nlpodyssey/gotokenizers/vocabulary"
 )
 
@@ -193,7 +194,7 @@ func (m *BPEModel) wordToTokens(word *Word) ([]models.Token, error) {
 		tokens[i] = models.Token{
 			ID:      wordSymbol.ID,
 			Value:   value,
-			Offsets: models.TokenOffsets{
+			Offsets: strutils.ByteOffsets{
 				Start: offsetStart,
 				End:   offsetEnd,
 			},
diff --git a/models/bpemodel/bpemodel_test.go b/models/bpemodel/bpemodel_test.go
index c98948d..78a5da7 100644
--- a/models/bpemodel/bpemodel_test.go
+++ b/models/bpemodel/bpemodel_test.go
@@ -6,6 +6,7 @@ package bpemodel
 import (
 	"github.com/nlpodyssey/gotokenizers/models"
+	"github.com/nlpodyssey/gotokenizers/strutils"
 	"github.com/nlpodyssey/gotokenizers/vocabulary"
 	"reflect"
 	"testing"
@@ -81,7 +82,7 @@ func TestTokenizeWithAndWithoutDropout(t *testing.T) {
 		{
 			ID:      15,
 			Value:   "unrelated",
-			Offsets: models.TokenOffsets{Start: 0, End: 9},
+			Offsets: strutils.ByteOffsets{Start: 0, End: 9},
 		},
 	}
 	if !reflect.DeepEqual(tokens, expectedTokens) {
@@ -105,15 +106,15 @@ func TestTokenizeWithAndWithoutDropout(t *testing.T) {
 	}
 
 	expectedTokens = []models.Token{
-		{ID: 0, Value: "u", Offsets: models.TokenOffsets{Start: 0, End: 1}},
-		{ID: 1, Value: "n", Offsets: models.TokenOffsets{Start: 1, End: 2}},
-		{ID: 2, Value: "r", Offsets: models.TokenOffsets{Start: 2, End: 3}},
-		{ID: 3, Value: "e", Offsets: models.TokenOffsets{Start: 3, End: 4}},
-		{ID: 4, Value: "l", Offsets: models.TokenOffsets{Start: 4, End: 5}},
-		{ID: 5, Value: "a", Offsets: models.TokenOffsets{Start: 5, End: 6}},
-		{ID: 6, Value: "t", Offsets: models.TokenOffsets{Start: 6, End: 7}},
-		{ID: 3, Value: "e", Offsets: models.TokenOffsets{Start: 7, End: 8}},
-		{ID: 7, Value: "d", Offsets: models.TokenOffsets{Start: 8, End: 9}},
+		{ID: 0, Value: "u", Offsets: strutils.ByteOffsets{Start: 0, End: 1}},
+		{ID: 1, Value: "n", Offsets: strutils.ByteOffsets{Start: 1, End: 2}},
+		{ID: 2, Value: "r", Offsets: strutils.ByteOffsets{Start: 2, End: 3}},
+		{ID: 3, Value: "e", Offsets: strutils.ByteOffsets{Start: 3, End: 4}},
+		{ID: 4, Value: "l", Offsets: strutils.ByteOffsets{Start: 4, End: 5}},
+		{ID: 5, Value: "a", Offsets: strutils.ByteOffsets{Start: 5, End: 6}},
+		{ID: 6, Value: "t", Offsets: strutils.ByteOffsets{Start: 6, End: 7}},
+		{ID: 3, Value: "e", Offsets: strutils.ByteOffsets{Start: 7, End: 8}},
+		{ID: 7, Value: "d", Offsets: strutils.ByteOffsets{Start: 8, End: 9}},
 	}
 	if !reflect.DeepEqual(tokens, expectedTokens) {
 		t.Errorf("expected %+v, actual %+v", expectedTokens, tokens)
diff --git a/models/models.go b/models/models.go
index f24621b..c8a1b04 100644
--- a/models/models.go
+++ b/models/models.go
@@ -4,6 +4,8 @@
 package models
 
+import "github.com/nlpodyssey/gotokenizers/strutils"
+
 // Model represents a model used during Tokenization (like BPE or Word
 // or Unigram).
 type Model interface {
 	// Tokenize tokenizes the given sequence into multiple underlying Tokens.
@@ -14,10 +16,5 @@
 type Token struct {
 	ID      int
 	Value   string
-	Offsets TokenOffsets
-}
-
-type TokenOffsets struct {
-	Start int
-	End   int
+	Offsets strutils.ByteOffsets
 }
diff --git a/models/wordpiecemodel/wordpiecemodel.go b/models/wordpiecemodel/wordpiecemodel.go
index 8d746bd..3dbba9f 100644
--- a/models/wordpiecemodel/wordpiecemodel.go
+++ b/models/wordpiecemodel/wordpiecemodel.go
@@ -7,6 +7,7 @@ package wordpiecemodel
 import (
 	"fmt"
 	"github.com/nlpodyssey/gotokenizers/models"
+	"github.com/nlpodyssey/gotokenizers/strutils"
 	"github.com/nlpodyssey/gotokenizers/vocabulary"
 )
 
@@ -60,7 +61,7 @@ func (m *WordPieceModel) Tokenize(sequence string) ([]models.Token, error) {
 		return []models.Token{{
 			ID:      unkTokenID,
 			Value:   m.unknownToken,
-			Offsets: models.TokenOffsets{Start: 0, End: len(sequence)},
+			Offsets: strutils.ByteOffsets{Start: 0, End: len(sequence)},
 		}}, nil
 	}
 
@@ -85,7 +86,7 @@ func (m *WordPieceModel) Tokenize(sequence string) ([]models.Token, error) {
 			curToken = models.Token{
 				ID:      id,
 				Value:   subStr,
-				Offsets: models.TokenOffsets{Start: start, End: end},
+				Offsets: strutils.ByteOffsets{Start: start, End: end},
 			}
 			break
 		}
@@ -114,7 +115,7 @@ func (m *WordPieceModel) Tokenize(sequence string) ([]models.Token, error) {
 		return []models.Token{{
 			ID:      unkTokenID,
 			Value:   m.unknownToken,
-			Offsets: models.TokenOffsets{Start: 0, End: len(sequence)},
+			Offsets: strutils.ByteOffsets{Start: 0, End: len(sequence)},
 		}}, nil
 	}
diff --git a/models/wordpiecemodel/wordpiecemodel_test.go b/models/wordpiecemodel/wordpiecemodel_test.go
index 6fb1e8c..87ea6fa 100644
--- a/models/wordpiecemodel/wordpiecemodel_test.go
+++ b/models/wordpiecemodel/wordpiecemodel_test.go
@@ -7,6 +7,7 @@ package wordpiecemodel
 import (
 	"fmt"
 	"github.com/nlpodyssey/gotokenizers/models"
+	"github.com/nlpodyssey/gotokenizers/strutils"
 	"github.com/nlpodyssey/gotokenizers/vocabulary"
 	"reflect"
 	"testing"
@@ -48,40 +49,40 @@ func TestWordPieceModelTokenize(t *testing.T) {
 		{
 			"foo",
 			[]models.Token{
-				{ID: 1, Value: "foo", Offsets: models.TokenOffsets{Start: 0, End: 3}},
+				{ID: 1, Value: "foo", Offsets: strutils.ByteOffsets{Start: 0, End: 3}},
 			},
 		},
 		{
 			"barbaz",
 			[]models.Token{
-				{ID: 3, Value: "bar", Offsets: models.TokenOffsets{Start: 0, End: 3}},
-				{ID: 6, Value: "##baz", Offsets: models.TokenOffsets{Start: 3, End: 6}},
+				{ID: 3, Value: "bar", Offsets: strutils.ByteOffsets{Start: 0, End: 3}},
+				{ID: 6, Value: "##baz", Offsets: strutils.ByteOffsets{Start: 3, End: 6}},
 			},
 		},
 		{
 			"alphabetagamma",
 			[]models.Token{
-				{ID: 0, Value: "[UNK]", Offsets: models.TokenOffsets{Start: 0, End: 14}},
+				{ID: 0, Value: "[UNK]", Offsets: strutils.ByteOffsets{Start: 0, End: 14}},
 			},
 		},
 		{
 			"foobarbaz",
 			[]models.Token{
-				{ID: 1, Value: "foo", Offsets: models.TokenOffsets{Start: 0, End: 3}},
-				{ID: 4, Value: "##bar", Offsets: models.TokenOffsets{Start: 3, End: 6}},
-				{ID: 6, Value: "##baz", Offsets: models.TokenOffsets{Start: 6, End: 9}},
+				{ID: 1, Value: "foo", Offsets: strutils.ByteOffsets{Start: 0, End: 3}},
+				{ID: 4, Value: "##bar", Offsets: strutils.ByteOffsets{Start: 3, End: 6}},
+				{ID: 6, Value: "##baz", Offsets: strutils.ByteOffsets{Start: 6, End: 9}},
 			},
 		},
 		{
 			"qux",
 			[]models.Token{
-				{ID: 0, Value: "[UNK]", Offsets: models.TokenOffsets{Start: 0, End: 3}},
+				{ID: 0, Value: "[UNK]", Offsets: strutils.ByteOffsets{Start: 0, End: 3}},
 			},
 		},
 		{
 			"veryverylongterm",
 			[]models.Token{
-				{ID: 0, Value: "[UNK]", Offsets: models.TokenOffsets{Start: 0, End: 16}},
+				{ID: 0, Value: "[UNK]", Offsets: strutils.ByteOffsets{Start: 0, End: 16}},
 			},
 		},
 	}
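
Note: the patch relocates the offsets type without changing its shape. Every rewritten composite literal above keeps the exported Start/End int fields, and the WordPiece hunks (End: len(sequence)) show they remain byte indices into the input string. Below is a minimal sketch of a caller after this change; it assumes only what the hunks themselves show about strutils.ByteOffsets, and the file name, main function, and sequence variable are illustrative, not code from the repository.

// example_usage.go - illustrative sketch, not part of the patch.
package main

import (
	"fmt"

	"github.com/nlpodyssey/gotokenizers/models"
	"github.com/nlpodyssey/gotokenizers/strutils"
)

func main() {
	// After the patch, models.Token stores strutils.ByteOffsets in its
	// Offsets field; models.TokenOffsets no longer exists.
	tok := models.Token{
		ID:      15,
		Value:   "unrelated",
		Offsets: strutils.ByteOffsets{Start: 0, End: 9},
	}

	sequence := "unrelated"
	// Start and End are byte indices, so recovering the surface form is a
	// plain string slice, valid as long as the offsets were produced on
	// this same byte sequence.
	fmt.Println(sequence[tok.Offsets.Start:tok.Offsets.End] == tok.Value) // true
}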