From: Taku Kudo Date: Sun, 19 Jun 2022 15:55:46 +0000 (+0900) Subject: Added ImmutableSentencePiece class X-Git-Tag: archive/raspbian/0.1.97-3+rpi1^2~18 X-Git-Url: https://dgit.raspbian.org/?a=commitdiff_plain;h=bf8cc3e14d70a1ae7e0cdd4964144d7169888970;p=sentencepiece.git Added ImmutableSentencePiece class Signed-off-by: Kentaro Hayashi Gbp-Pq: Name 0010-Added-ImmutableSentencePiece-class.patch --- diff --git a/src/bpe_model.cc b/src/bpe_model.cc index 22cd115..bc7ada1 100644 --- a/src/bpe_model.cc +++ b/src/bpe_model.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License.! +#include "bpe_model.h" + #include #include #include @@ -19,7 +21,6 @@ #include #include -#include "bpe_model.h" #include "freelist.h" #include "third_party/absl/container/flat_hash_map.h" #include "util.h" @@ -71,8 +72,7 @@ std::vector> Model::SampleEncode( // Reverse merge rules. // key: merged symbol, value: pair of original symbols. absl::flat_hash_map, - string_util::string_view_hash> + std::pair> rev_merge; // Pre-allocates SymbolPair for efficiency. diff --git a/src/model_interface.h b/src/model_interface.h index 06b3a65..06e9243 100644 --- a/src/model_interface.h +++ b/src/model_interface.h @@ -53,8 +53,8 @@ class ModelProto; // Given a normalized string, returns a sequence of sentence pieces with ids. class ModelInterface { public: - using PieceToIdMap = absl::flat_hash_map; + using PieceToIdMap = absl::flat_hash_map; + // string_util::string_view_hash>; absl::string_view unk_piece() const; absl::string_view bos_piece() const; @@ -77,19 +77,6 @@ class ModelInterface { return matcher_.get(); } - // Sets the encoder version. Currently only unigram has an optimized encoder. - // The optimized version is always used by default if there is one, so - // normally users do not need to call this function. This function is provided - // just in case that a user want to manually choose which encoder version to - // use. 
- virtual util::Status SetEncoderVersion(EncoderVersion encoder_version) { - encoder_version_ = encoder_version; - return util::OkStatus(); - } - - // Returns the current encoder version in use. - virtual EncoderVersion GetEncoderVersion() const { return encoder_version_; } - // Given a normalized string, returns a sequence of sentence pieces with ids. // The concatenation of pieces must be the same as `normalized`. virtual EncodeResult Encode(absl::string_view normalized) const = 0; @@ -123,10 +110,9 @@ class ModelInterface { } // Calculates the entropy of the segmentation lattice with inverse temperature - // `theta`. - // Uses a novel dynamic program to calculate the entropy. + // `alpha`. Uses a novel dynamic program to calculate the entropy. virtual float CalculateEntropy(absl::string_view normalized, - float theta) const { + float alpha) const { LOG(ERROR) << "Not implemented."; return 0.0; } @@ -256,10 +242,6 @@ class ModelInterface { // unknown id. int unk_id_ = 0; - // The encoder version. Currently it is only effective for unigram model but - // ignored by other models. - EncoderVersion encoder_version_ = EncoderVersion::kOptimized; - // status. util::Status status_; }; diff --git a/src/model_interface_test.cc b/src/model_interface_test.cc index 69ee4e6..09e41d3 100644 --- a/src/model_interface_test.cc +++ b/src/model_interface_test.cc @@ -12,8 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License.! 
-#include "model_factory.h" #include "model_interface.h" + +#include "model_factory.h" #include "testharness.h" #include "third_party/absl/container/flat_hash_map.h" #include "util.h" @@ -481,22 +482,6 @@ TEST(ModelInterfaceTest, PieceToByteTest) { EXPECT_EQ(PieceToByte("a"), -1); } -TEST(ModelInterfaceTest, SetEncoderVersion) { - for (const auto type : kModelTypes) { - ModelProto model_proto = MakeBaseModelProto(type); - AddPiece(&model_proto, "a"); - AddPiece(&model_proto, "b"); - auto model = ModelFactory::Create(model_proto); - - // Verify the default encoder version. - EXPECT_EQ(EncoderVersion::kOptimized, model->GetEncoderVersion()); - - // Set the encoder version to original and verify. - EXPECT_TRUE(model->SetEncoderVersion(EncoderVersion::kOriginal).ok()); - EXPECT_EQ(EncoderVersion::kOriginal, model->GetEncoderVersion()); - } -} - TEST(ModelInterfaceTest, VerifyOutputsEquivalent) { for (const auto type : kModelTypes) { ModelProto model_proto = MakeBaseModelProto(type); diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc index 331fc90..a6f5395 100644 --- a/src/sentencepiece_processor.cc +++ b/src/sentencepiece_processor.cc @@ -56,6 +56,112 @@ std::vector ToPieceArray(const std::vector &v) { } } // namespace +ImmutableSentencePieceText::ImmutableSentencePieceText() {} +ImmutableSentencePieceText::~ImmutableSentencePieceText() {} + +ImmutableSentencePieceText::ImmutableSentencePieceText( + const SentencePieceText &spt) + : spt_(&spt) {} + +ImmutableSentencePieceText::ImmutableSentencePiece::ImmutableSentencePiece( + const SentencePieceText_SentencePiece &sp) + : sp_(&sp) {} + +absl::string_view ImmutableSentencePieceText::ImmutableSentencePiece::piece() + const { + return sp_->piece(); +} + +absl::string_view ImmutableSentencePieceText::ImmutableSentencePiece::surface() + const { + return sp_->surface(); +} + +uint32_t ImmutableSentencePieceText::ImmutableSentencePiece::id() const { + return sp_->id(); +} + +uint32_t 
ImmutableSentencePieceText::ImmutableSentencePiece::begin() const { + return sp_->begin(); +} + +uint32_t ImmutableSentencePieceText::ImmutableSentencePiece::end() const { + return sp_->end(); +} + +std::vector +ImmutableSentencePieceText::pieces() const { + std::vector pieces; + if (spt_ == nullptr) return pieces; + pieces.reserve(spt_->pieces_size()); + for (int i = 0; i < spt_->pieces_size(); ++i) + pieces.push_back(ImmutableSentencePiece(spt_->pieces(i))); + return pieces; +} + +size_t ImmutableSentencePieceText::pieces_size() const { + return spt_ ? spt_->pieces_size() : 0; +} + +ImmutableSentencePieceText::ImmutableSentencePiece +ImmutableSentencePieceText::pieces(int index) const { + return ImmutableSentencePieceText::ImmutableSentencePiece( + spt_->pieces(index)); +} + +absl::string_view ImmutableSentencePieceText::text() const { + return spt_ ? spt_->text() : ""; +} + +float ImmutableSentencePieceText::score() const { + return spt_ ? spt_->score() : 0.0; +} + +SentencePieceText *ImmutableSentencePieceText::mutable_proto() { + if (rep_ == nullptr) { + rep_ = std::make_shared(); + spt_ = rep_.get(); + } + return rep_.get(); +} + +std::string ImmutableSentencePieceText::SerializeAsString() const { + return spt_ ? spt_->SerializeAsString() : ""; +} + +ImmutableNBestSentencePieceText::ImmutableNBestSentencePieceText() {} +ImmutableNBestSentencePieceText::~ImmutableNBestSentencePieceText() {} + +size_t ImmutableNBestSentencePieceText::nbests_size() const { + return rep_ ? 
rep_->nbests_size() : 0; +} + +ImmutableSentencePieceText ImmutableNBestSentencePieceText::nbests( + int index) const { + return ImmutableSentencePieceText(rep_->nbests(index)); +} + +std::vector +ImmutableNBestSentencePieceText::nbests() const { + std::vector nbests; + if (rep_ == nullptr) return nbests; + nbests.reserve(rep_->nbests_size()); + for (int i = 0; i < rep_->nbests_size(); ++i) + nbests.push_back(ImmutableSentencePieceText(rep_->nbests(i))); + return nbests; +} + +NBestSentencePieceText *ImmutableNBestSentencePieceText::mutable_proto() { + if (rep_ == nullptr) { + rep_ = std::make_shared(); + } + return rep_.get(); +} + +std::string ImmutableNBestSentencePieceText::SerializeAsString() const { + return rep_ ? rep_->SerializeAsString() : ""; +} + SentencePieceProcessor::SentencePieceProcessor() {} SentencePieceProcessor::~SentencePieceProcessor() {} @@ -124,15 +230,6 @@ util::Status SentencePieceProcessor::Load( return util::OkStatus(); } -util::Status SentencePieceProcessor::SetEncoderVersion( - EncoderVersion encoder_version) { - return model_->SetEncoderVersion(encoder_version); -} - -EncoderVersion SentencePieceProcessor::GetEncoderVersion() const { - return model_->GetEncoderVersion(); -} - util::Status SentencePieceProcessor::SetEncodeExtraOptions( absl::string_view extra_options) { return ParseExtraOptions(extra_options, &encode_extra_options_); @@ -348,14 +445,14 @@ util::Status SentencePieceProcessor::SampleEncode(absl::string_view input, } util::Status SentencePieceProcessor::SampleEncodeAndScore( - absl::string_view input, int num_samples, float theta, bool wor, + absl::string_view input, int num_samples, float alpha, bool wor, bool include_best, std::vector, float>> *pieces) const { CHECK_OR_RETURN_STATUS_STL(pieces); NBestSentencePieceText spt; RETURN_IF_ERROR( - SampleEncodeAndScore(input, num_samples, theta, wor, include_best, &spt)); + SampleEncodeAndScore(input, num_samples, alpha, wor, include_best, &spt)); pieces->clear(); 
pieces->reserve(spt.nbests_size()); @@ -373,14 +470,14 @@ util::Status SentencePieceProcessor::SampleEncodeAndScore( } util::Status SentencePieceProcessor::SampleEncodeAndScore( - absl::string_view input, int num_samples, float theta, bool wor, + absl::string_view input, int num_samples, float alpha, bool wor, bool include_best, std::vector, float>> *ids) const { CHECK_OR_RETURN_STATUS_STL(ids); NBestSentencePieceText spt; RETURN_IF_ERROR( - SampleEncodeAndScore(input, num_samples, theta, wor, include_best, &spt)); + SampleEncodeAndScore(input, num_samples, alpha, wor, include_best, &spt)); ids->clear(); ids->reserve(spt.nbests_size()); @@ -568,7 +665,7 @@ util::Status SentencePieceProcessor::SampleEncode( } util::Status SentencePieceProcessor::SampleEncodeAndScore( - absl::string_view input, int samples, float theta, bool wor, + absl::string_view input, int samples, float alpha, bool wor, bool include_best, NBestSentencePieceText *samples_spt) const { CHECK_OR_RETURN(model_->IsSampleEncodeAndScoreAvailable()) << "SampleEncodeAndScore is not available for the current model."; @@ -576,7 +673,7 @@ util::Status SentencePieceProcessor::SampleEncodeAndScore( std::vector norm_to_orig; RETURN_IF_ERROR(normalizer_->Normalize(input, &normalized, &norm_to_orig)); - const auto results = model_->SampleEncodeAndScore(normalized, theta, samples, + const auto results = model_->SampleEncodeAndScore(normalized, alpha, samples, wor, include_best); CHECK_OR_RETURN(!results.empty()) << "SampleEncodeAndScore returns empty result."; @@ -592,7 +689,7 @@ util::Status SentencePieceProcessor::SampleEncodeAndScore( } util::Status SentencePieceProcessor::CalculateEntropy(absl::string_view input, - float theta, + float alpha, float *entropy) const { CHECK_OR_RETURN(model_->IsCalculateEntropyAvailable()) << "CalculateEntropy is not available for the current model."; @@ -600,7 +697,7 @@ util::Status SentencePieceProcessor::CalculateEntropy(absl::string_view input, std::vector norm_to_orig; 
RETURN_IF_ERROR(normalizer_->Normalize(input, &normalized, &norm_to_orig)); - *entropy = model_->CalculateEntropy(normalized, theta); + *entropy = model_->CalculateEntropy(normalized, alpha); return util::OkStatus(); } @@ -770,48 +867,6 @@ util::Status SentencePieceProcessor::Decode(const std::vector &ids, return Decode(pieces, spt); } -std::string SentencePieceProcessor::EncodeAsSerializedProto( - absl::string_view input) const { - SentencePieceText spt; - if (!Encode(input, &spt).ok()) return ""; - return spt.SerializeAsString(); -} - -std::string SentencePieceProcessor::SampleEncodeAsSerializedProto( - absl::string_view input, int nbest_size, float alpha) const { - SentencePieceText spt; - if (!SampleEncode(input, nbest_size, alpha, &spt).ok()) return ""; - return spt.SerializeAsString(); -} - -std::string SentencePieceProcessor::NBestEncodeAsSerializedProto( - absl::string_view input, int nbest_size) const { - NBestSentencePieceText spt; - if (!NBestEncode(input, nbest_size, &spt).ok()) return ""; - return spt.SerializeAsString(); -} - -std::string SentencePieceProcessor::DecodePiecesAsSerializedProto( - const std::vector &pieces) const { - SentencePieceText spt; - if (!Decode(pieces, &spt).ok()) return ""; - return spt.SerializeAsString(); -} - -std::string SentencePieceProcessor::DecodePiecesAsSerializedProto( - const std::vector &pieces) const { - SentencePieceText spt; - if (!Decode(pieces, &spt).ok()) return ""; - return spt.SerializeAsString(); -} - -std::string SentencePieceProcessor::DecodeIdsAsSerializedProto( - const std::vector &ids) const { - SentencePieceText spt; - if (!Decode(ids, &spt).ok()) return ""; - return spt.SerializeAsString(); -} - #define CHECK_STATUS_OR_RETURN_DEFAULT(value) \ if (!status().ok()) { \ LOG(ERROR) << status().message() << "\nReturns default value " << value; \ diff --git a/src/sentencepiece_processor.h b/src/sentencepiece_processor.h index 8c72656..51c5b3b 100644 --- a/src/sentencepiece_processor.h +++ 
b/src/sentencepiece_processor.h @@ -29,11 +29,6 @@ using std::string_view; #endif // SWIG namespace sentencepiece { - -#ifndef SWIG -using EncodeResult = std::vector>; -#endif // SWIG - namespace util { enum class StatusCode : int { @@ -107,17 +102,17 @@ class Status { // sp.Load("//path/to/model"); // // vector sps; -// sp.Encode("hello world.", &sps); +// sp.Encode("hello world.", &sps).IgnoreError(); // // vector ids; -// sp.Encode("hello world.", &ids); +// sp.Encode("hello world.", &ids).IgnoreError(); // // string detok; -// sp.Decode(sps, &detok); +// sp.Decode(sps, &detok).IgnoreError(); // CHECK_EQ("hello world.", detok); // -// sp.Decode(ids, &detok); +// sp.Decode(ids, &detok).IgnoreError(); // CHECK_EQ("hello world.", detok); // // We can also use SentencePieceText which manages the byte-offsets // between user input (output) and internal sentence pieces. @@ -144,16 +139,6 @@ namespace normalizer { class Normalizer; } // namespace normalizer -#ifndef SWIG -// Defines the multiple versions of encoder within each model. Currently only -// the Unigram model has an optimized encoder. -enum class EncoderVersion { - kOptimized, // The optimized encoder (default). - kOriginal // The original encoder (user may choose to fall back to this - // just in case). -}; -#endif - #ifndef SWIGGO namespace util { // Redefine std::string for serialized_proto interface as Python's string is @@ -161,7 +146,87 @@ namespace util { // with SWIG's typemap. using bytes = std::string; } // namespace util -#endif +#endif // SWIGGO + +class NBestSentencePieceText; +class ModelInterface; +class SentencePieceText; +class SentencePieceText_SentencePiece; + +// Wrapper class of SentencePieceText +// This wrapper only allows immutable access to the proto and +// hides the actual implementation of protobuf. +// See sentencepiece.proto for the details of this class. 
+class ImmutableSentencePieceText { + public: + ImmutableSentencePieceText(); + virtual ~ImmutableSentencePieceText(); + + class ImmutableSentencePiece { + public: + ~ImmutableSentencePiece() = default; + absl::string_view piece() const; + absl::string_view surface() const; + uint32_t id() const; + uint32_t begin() const; + uint32_t end() const; + + friend class ImmutableSentencePieceText; + + private: + ImmutableSentencePiece() = default; + explicit ImmutableSentencePiece(const SentencePieceText_SentencePiece &sp); + const SentencePieceText_SentencePiece *sp_ = nullptr; + }; + + std::vector pieces() const; + size_t pieces_size() const; + ImmutableSentencePiece pieces(int index) const; + absl::string_view text() const; + float score() const; + + std::string SerializeAsString() const; + + // Returns the actual mutable proto. + // Do not use this outside of SentencePieceProcessor, as + // it returns the raw pointer managed by the shared_ptr. + SentencePieceText *mutable_proto(); + + friend class ImmutableNBestSentencePieceText; + friend class SentencePieceProcessor; + + private: + explicit ImmutableSentencePieceText(const SentencePieceText &spt); + const SentencePieceText *spt_ = nullptr; + std::shared_ptr rep_; +}; + +// Wrapper class of NBestSentencePieceText +// This wrapper only allows immutable access to the proto and +// hides the actual implementation of protobuf. +// See sentencepiece.proto for the details of this class. +class ImmutableNBestSentencePieceText { + public: + ImmutableNBestSentencePieceText(); + virtual ~ImmutableNBestSentencePieceText(); + + std::vector nbests() const; + + size_t nbests_size() const; + ImmutableSentencePieceText nbests(int index) const; + + std::string SerializeAsString() const; + + // Returns the actual mutable proto. + // Do not use this outside of SentencePieceProcessor, as + // it returns the raw pointer managed by the shared_ptr. 
+ NBestSentencePieceText *mutable_proto(); + + friend class SentencePieceProcessor; + + private: + std::shared_ptr rep_; +}; class SentencePieceProcessor { public: @@ -217,7 +282,7 @@ class SentencePieceProcessor { int threshold); ////////////////////////////////////////////////////////////// - // Simple API. + // Simple Encode and Decode API. // // Given a UTF8 input, encodes it into a sequence of sentence pieces. virtual util::Status Encode(absl::string_view input, @@ -239,18 +304,9 @@ class SentencePieceProcessor { virtual util::Status Decode(const std::vector &ids, std::string *detokenized) const; -#ifndef SWIG - // Sets the encoder version. Normally users do not need to call this function. - // But they can call this fucntion just in case if they want to fall back to - // the original encoder. - virtual util::Status SetEncoderVersion(EncoderVersion encoder_version); - - // Returns the current encoder version in use. - virtual EncoderVersion GetEncoderVersion() const; -#endif - ////////////////////////////////////////////////////////////// // NBest API. + // // Same as Encode, but returns nbest results. virtual util::Status NBestEncode( absl::string_view input, int nbest_size, @@ -262,24 +318,24 @@ class SentencePieceProcessor { ////////////////////////////////////////////////////////////// // Sampling API. + // // Unigram and BPE support sampling mode. // - Unigram (--model_type=unigram): - // When `nbest_size` is positive value, approximately samples one - // segmentation from nbest candidates. When `nbest_size` is negative value, - // samples one segmentation from the hypotheses (Lattice) according to the - // generation probabilities using forward-filtering and backward-sampling - // algorithm. `alpha` is a smoothing parameter. The best segmentation - // (Viterbi segmentation) is more likely sampled when setting larger - // alpha. When alpha is 0.0, one segmentation is uniformly sampled from the - // nbest or lattice. 
- // `nbest_size` and `alpha` correspond to parameters `l` and `alpha` + // `nbest_size`: When `nbest_size` is positive value, approximately samples + // one segmentation from nbest candidates. When `nbest_size` is negative + // value, samples one segmentation from the hypotheses (Lattice) according to + // the generation probabilities using forward-filtering and backward-sampling + // algorithm. + // `alpha`: Smoothing parameter (inverse temperature). The best segmentation + // (Viterbi segmentation) is more likely sampled when setting larger alpha. + // When alpha is 0.0, one segmentation is uniformly sampled from the nbest or + // lattice. `nbest_size` and `alpha` correspond to parameters `l` and `alpha` // in https://arxiv.org/abs/1804.10959 (nbest_size < 0 means l = infinity) // // - BPE (--model_type=bpe): - // `alpha` is the dropout probability `p` of bpe merge operations - // in https://arxiv.org/abs/1910.13267 - // Nbest-based sampling is not supported so nbest_size parameter is ignored in - // BPE. + // `alpha`: The dropout probability `p` of bpe merge operations in + // https://arxiv.org/abs/1910.13267 Nbest-based sampling is not supported so + // nbest_size parameter is ignored in BPE. virtual util::Status SampleEncode(absl::string_view input, int nbest_size, float alpha, std::vector *pieces) const; @@ -290,74 +346,104 @@ class SentencePieceProcessor { ////////////////////////////////////////////////////////////// // SampleEncodeAndScore API. - // Similar to SampleEncode, but returns samples results. + // + // Sample `samples` many tokenisations from the segmentation lattice. + // These methods are only available in model_type=unigram. + // + // `alpha`: smoothing parameter (inverse temperature). The same as `alpha` in + // `Sample` method. 
+ // `wor`: If `wor` is true, the samples are taken without replacement, and the + // scores are the inclusion probabilities of the elements in the sample; + // otherwise the samples are taken with replacement and the scores are the + // log-probs of sample elements. + // `include_best`: If `include_best` is true, the best tokenisation is always + // included in the sample, and the remaining elements are sampled excluding + // the best. virtual util::Status SampleEncodeAndScore( - absl::string_view input, int num_samples, float theta, bool wor, + absl::string_view input, int num_samples, float alpha, bool wor, bool include_best, std::vector, float>> *pieces) const; // Same as above, but returns a sequence of ids. virtual util::Status SampleEncodeAndScore( - absl::string_view input, int num_samples, float theta, bool wor, + absl::string_view input, int num_samples, float alpha, bool wor, bool include_best, std::vector, float>> *ids) const; + ////////////////////////////////////////////////////////////// + // Entropy API. + // + // This is only available in model_type=unigram. + // Calculates the entropy of possible tokenisations. + virtual util::Status CalculateEntropy(absl::string_view input, float alpha, + float *entropy) const; + ////////////////////////////////////////////////////////////// // Advanced API returning SentencePieceText, which manages // utf8-byte alignments between user-input/detokenized text // and internal sentencepiece sequence. // // Given a UTF8 input, encodes it into SentencePieceText. + // + // When using these APIs, the sentencepiece.pb.h header file must be included. + // We can also use ImmutableSentencePieceText as follows. + // + // ImmutableSentencePieceText spt; + // Encode("hello", spt.mutable_proto()).IgnoreError(); + // std::cout << spt.pieces_size() << std::endl; virtual util::Status Encode(absl::string_view input, SentencePieceText *spt) const; - // Same as above, but returns NBestSentencePieceText. 
virtual util::Status NBestEncode(absl::string_view input, int nbest_size, NBestSentencePieceText *nbest_spt) const; - // Same as above, but samples one segmentation from the hypotheses - // (Lattice). virtual util::Status SampleEncode(absl::string_view input, int nbest_size, float alpha, SentencePieceText *spt) const; - // Samples N segmentation and returns the scores as well virtual util::Status SampleEncodeAndScore( - absl::string_view input, int samples, float theta, bool wor, + absl::string_view input, int samples, float alpha, bool wor, bool include_best, NBestSentencePieceText *samples_spt) const; - // Calculate entropy of possible tokenisations - virtual util::Status CalculateEntropy(absl::string_view input, float theta, - float *entropy) const; - - // Given a sequence of pieces, decodes it into SentencePieceText. - // TODO(taku): Remove this API and use std::vector + // DEPRECATED: Remove this API and use std::vector virtual util::Status Decode(const std::vector &pieces, SentencePieceText *spt) const; - // Given a sequence of pieces, decodes it into SentencePieceText. virtual util::Status Decode(const std::vector &pieces, SentencePieceText *spt) const; - // Given a sequence of ids, decodes it into SentencePieceText. virtual util::Status Decode(const std::vector &ids, SentencePieceText *spt) const; - ////////////////////////////////////////////////////////////// - // Handy methods that return the result directly. - // These functions ignore internal errors. #ifdef SWIG -#define DEFINE_SPP_DIRECT_FUNC_IMPL(FuncName, OutType, ...) \ - OutType output; \ - const auto _status = FuncName(__VA_ARGS__, &output); \ - if (!_status.ok()) throw _status; \ - return output; +#define SPP_SWIG_CHECK_AND_THROW \ + if (!status.ok()) throw status; #else +#define SPP_SWIG_CHECK_AND_THROW \ + if (!status.ok()) { \ + } +#endif // SWIG + #define DEFINE_SPP_DIRECT_FUNC_IMPL(FuncName, OutType, ...) 
\ OutType output; \ - FuncName(__VA_ARGS__, &output).IgnoreError(); \ + const auto status = FuncName(__VA_ARGS__, &output); \ + SPP_SWIG_CHECK_AND_THROW; \ + return output; + +#define DEFINE_SPP_SERIALIZED_PROTO_IMPL(FuncName, OutType, ...) \ + OutType output; \ + const auto status = FuncName(__VA_ARGS__, output.mutable_proto()); \ + SPP_SWIG_CHECK_AND_THROW; \ + return output.SerializeAsString(); + +#define DEFINE_SPP_IMMUTABLE_PROTO_IMPL(FuncName, OutType, ...) \ + OutType output; \ + const auto status = FuncName(__VA_ARGS__, output.mutable_proto()); \ + SPP_SWIG_CHECK_AND_THROW; \ return output; -#endif + ////////////////////////////////////////////////////////////// + // Handy methods that return the result directly. + // These functions ignore internal errors. virtual std::vector EncodeAsPieces( absl::string_view input) const { DEFINE_SPP_DIRECT_FUNC_IMPL(Encode, std::vector, input); @@ -395,21 +481,21 @@ class SentencePieceProcessor { virtual std::vector, float>> SampleEncodeAndScoreAsPieces(absl::string_view input, int num_samples, - float theta, bool wor, bool include_best) const { + float alpha, bool wor, bool include_best) const { using _T = std::vector, float>>; DEFINE_SPP_DIRECT_FUNC_IMPL(SampleEncodeAndScore, _T, input, num_samples, - theta, wor, include_best); + alpha, wor, include_best); } virtual std::vector, float>> SampleEncodeAndScoreAsIds(absl::string_view input, int num_samples, - float theta, bool wor, bool include_best) const { + float alpha, bool wor, bool include_best) const { using _T = std::vector, float>>; DEFINE_SPP_DIRECT_FUNC_IMPL(SampleEncodeAndScore, _T, input, num_samples, - theta, wor, include_best); + alpha, wor, include_best); } - // TODO(taku): Remove this API and use std::vector + // DEPRECATED: Remove this API and use std::vector virtual std::string DecodePieces( const std::vector &pieces) const { DEFINE_SPP_DIRECT_FUNC_IMPL(Decode, std::string, pieces); @@ -424,33 +510,104 @@ class SentencePieceProcessor { 
DEFINE_SPP_DIRECT_FUNC_IMPL(Decode, std::string, ids); } - virtual float CalculateEntropy(absl::string_view text, float theta) const { - DEFINE_SPP_DIRECT_FUNC_IMPL(CalculateEntropy, float, text, theta); + virtual float CalculateEntropy(absl::string_view text, float alpha) const { + DEFINE_SPP_DIRECT_FUNC_IMPL(CalculateEntropy, float, text, alpha); } -#undef DEFINE_SPP_DIRECT_FUNC_IMPL - + ////////////////////////////////////////////////////////////// + // SerializedProto API. (DEPRECATED). Use ImmutableProto API. // They are used in Python interface. Returns serialized proto. // In python module, we can get access to the full Proto after // deserialzing the returned byte sequence. - virtual util::bytes EncodeAsSerializedProto(absl::string_view input) const; + virtual util::bytes EncodeAsSerializedProto(absl::string_view input) const { + DEFINE_SPP_SERIALIZED_PROTO_IMPL(Encode, ImmutableSentencePieceText, input); + } virtual util::bytes SampleEncodeAsSerializedProto(absl::string_view input, int nbest_size, - float alpha) const; + float alpha) const { + DEFINE_SPP_SERIALIZED_PROTO_IMPL(SampleEncode, ImmutableSentencePieceText, + input, nbest_size, alpha); + } virtual util::bytes NBestEncodeAsSerializedProto(absl::string_view input, - int nbest_size) const; + int nbest_size) const { + DEFINE_SPP_SERIALIZED_PROTO_IMPL( + NBestEncode, ImmutableNBestSentencePieceText, input, nbest_size); + } + + virtual util::bytes SampleEncodeAndScoreAsSerializedProto( + absl::string_view input, int samples, float alpha, bool wor, + bool include_best, int nbest_size) const { + DEFINE_SPP_SERIALIZED_PROTO_IMPL(SampleEncodeAndScore, + ImmutableNBestSentencePieceText, input, + samples, alpha, wor, include_best); + } // TODO(taku): Remove this API and use std::vector virtual util::bytes DecodePiecesAsSerializedProto( - const std::vector &pieces) const; + const std::vector &pieces) const { + DEFINE_SPP_SERIALIZED_PROTO_IMPL(Decode, ImmutableSentencePieceText, + pieces); + } virtual 
util::bytes DecodePiecesAsSerializedProto( - const std::vector &pieces) const; + const std::vector &pieces) const { + DEFINE_SPP_SERIALIZED_PROTO_IMPL(Decode, ImmutableSentencePieceText, + pieces); + } virtual util::bytes DecodeIdsAsSerializedProto( - const std::vector &ids) const; + const std::vector &ids) const { + DEFINE_SPP_SERIALIZED_PROTO_IMPL(Decode, ImmutableSentencePieceText, ids); + } + + ////////////////////////////////////////////////////////////// + // ImmutableProto API. + virtual ImmutableSentencePieceText EncodeAsImmutableProto( + absl::string_view input) const { + DEFINE_SPP_IMMUTABLE_PROTO_IMPL(Encode, ImmutableSentencePieceText, input); + } + + virtual ImmutableSentencePieceText SampleEncodeAsImmutableProto( + absl::string_view input, int nbest_size, float alpha) const { + DEFINE_SPP_IMMUTABLE_PROTO_IMPL(SampleEncode, ImmutableSentencePieceText, + input, nbest_size, alpha); + } + + virtual ImmutableNBestSentencePieceText NBestEncodeAsImmutableProto( + absl::string_view input, int nbest_size) const { + DEFINE_SPP_IMMUTABLE_PROTO_IMPL( + NBestEncode, ImmutableNBestSentencePieceText, input, nbest_size); + } + + virtual ImmutableNBestSentencePieceText SampleEncodeAndScoreAsImmutableProto( + absl::string_view input, int samples, float alpha, bool wor, + bool include_best, int nbest_size) const { + DEFINE_SPP_IMMUTABLE_PROTO_IMPL(SampleEncodeAndScore, + ImmutableNBestSentencePieceText, input, + samples, alpha, wor, include_best); + } + + // TODO(taku): Remove this API and use std::vector + virtual ImmutableSentencePieceText DecodePiecesAsImmutableProto( + const std::vector &pieces) const { + DEFINE_SPP_IMMUTABLE_PROTO_IMPL(Decode, ImmutableSentencePieceText, pieces); + } + + virtual ImmutableSentencePieceText DecodePiecesAsImmutableProto( + const std::vector &pieces) const { + DEFINE_SPP_IMMUTABLE_PROTO_IMPL(Decode, ImmutableSentencePieceText, pieces); + } + + virtual ImmutableSentencePieceText DecodeIdsAsImmutableProto( + const std::vector &ids) const 
{ + DEFINE_SPP_IMMUTABLE_PROTO_IMPL(Decode, ImmutableSentencePieceText, ids); + } + +#undef DEFINE_SPP_DIRECT_FUNC_IMPL +#undef DEFINE_SPP_SERIALIZED_PROTO_IMPL +#undef DEFINE_SPP_IMMUTABLE_PROTO_IMPL ////////////////////////////////////////////////////////////// // Vocabulary management methods. @@ -467,7 +624,8 @@ class SentencePieceProcessor { virtual const std::string &IdToPiece(int id) const; // Returns the score of `id`. - // Usually score is an emission log probability of unigram language model. + // Usually score is an emission log probability of unigram language + // model. virtual float GetScore(int id) const; // Returns true if `id` is unknown symbol. @@ -506,7 +664,7 @@ class SentencePieceProcessor { // Allows injection of a normalizer instance. `normalizer` is moved. void SetNormalizer(std::unique_ptr &&normalizer); -#endif +#endif // SWIG // Returns immutable model proto. Useful to obtain extended // or experimental parameters encoded in model_proto. diff --git a/src/sentencepiece_processor_test.cc b/src/sentencepiece_processor_test.cc index d57ab5a..ed651f7 100644 --- a/src/sentencepiece_processor_test.cc +++ b/src/sentencepiece_processor_test.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License.! 
+#include "sentencepiece_processor.h" + #include #include "builder.h" @@ -20,7 +22,6 @@ #include "normalizer.h" #include "sentencepiece.pb.h" #include "sentencepiece_model.pb.h" -#include "sentencepiece_processor.h" #include "sentencepiece_trainer.h" #include "testharness.h" #include "third_party/absl/container/flat_hash_map.h" @@ -551,10 +552,9 @@ TEST(SentencepieceProcessorTest, DecodeTest) { int GetPieceSize() const override { return 7; } int PieceToId(absl::string_view piece) const override { - static absl::flat_hash_map - kMap = {{"", 0}, {"", 1}, {"", 2}, {WS "ABC", 3}, - {WS "DE", 4}, {"F", 5}, {"G" WS "H", 6}}; + static absl::flat_hash_map kMap = { + {"", 0}, {"", 1}, {"", 2}, {WS "ABC", 3}, + {WS "DE", 4}, {"F", 5}, {"G" WS "H", 6}}; return port::FindWithDefault(kMap, piece, 0); } @@ -719,10 +719,9 @@ TEST(SentencepieceProcessorTest, DummyPrefixDecodeTest) { int GetPieceSize() const override { return 7; } int PieceToId(absl::string_view piece) const override { - static absl::flat_hash_map - kMap = {{"", 0}, {"", 1}, {"", 2}, {WS "ABC", 3}, - {WS "DE", 4}, {"F", 5}, {"G" WS "H", 6}, {WS, 7}}; + static absl::flat_hash_map kMap = { + {"", 0}, {"", 1}, {"", 2}, {WS "ABC", 3}, + {WS "DE", 4}, {"F", 5}, {"G" WS "H", 6}, {WS, 7}}; return port::FindWithDefault(kMap, piece, 0); } @@ -1058,18 +1057,6 @@ TEST(SentencePieceProcessorTest, EndToEndTest) { EXPECT_EQ(2, sp.eos_id()); EXPECT_EQ(-1, sp.pad_id()); - { - // Verify the default encoder version. - EXPECT_EQ(EncoderVersion::kOptimized, sp.GetEncoderVersion()); - - // Set the encoder version to original and verify. - EXPECT_TRUE(sp.SetEncoderVersion(EncoderVersion::kOriginal).ok()); - EXPECT_EQ(EncoderVersion::kOriginal, sp.GetEncoderVersion()); - - // Set back to the default encoder version. 
- EXPECT_TRUE(sp.SetEncoderVersion(EncoderVersion::kOptimized).ok()); - } - { std::vector sps; const std::vector expected_str = {WS, "ab", "c"}; @@ -1574,4 +1561,77 @@ TEST(SentencePieceProcessorTest, VocabularyTest) { EXPECT_FALSE(sp.IsUnused(6)); EXPECT_FALSE(sp.IsUnused(7)); } + +TEST(SentencePieceProcessorTest, ImmutableSentencePieceTextTest) { + ImmutableSentencePieceText spt; + auto *v = spt.mutable_proto(); + + v->set_text("hello world"); + v->set_score(1.0); + for (int i = 0; i < 10; ++i) { + auto *p = v->add_pieces(); + p->set_surface(absl::StrCat("surface_", i)); + p->set_piece(absl::StrCat("surface_", i)); + p->set_id(i); + p->set_begin(i + 10); + p->set_end(i + 20); + } + + EXPECT_EQ(v->pieces_size(), spt.pieces_size()); + for (int i = 0; i < spt.pieces_size(); ++i) { + EXPECT_EQ(v->pieces(i).surface(), spt.pieces(i).surface()); + EXPECT_EQ(v->pieces(i).piece(), spt.pieces(i).piece()); + EXPECT_EQ(v->pieces(i).id(), spt.pieces(i).id()); + EXPECT_EQ(v->pieces(i).begin(), spt.pieces(i).begin()); + EXPECT_EQ(v->pieces(i).end(), spt.pieces(i).end()); + } + + int n = 0; + for (auto &p : spt.pieces()) { + EXPECT_EQ(v->pieces(n).surface(), p.surface()); + EXPECT_EQ(v->pieces(n).piece(), p.piece()); + EXPECT_EQ(v->pieces(n).id(), p.id()); + EXPECT_EQ(v->pieces(n).begin(), p.begin()); + EXPECT_EQ(v->pieces(n).end(), p.end()); + ++n; + } + + EXPECT_EQ(v->text(), spt.text()); + EXPECT_EQ(v->score(), spt.score()); + EXPECT_EQ(v->SerializeAsString(), spt.SerializeAsString()); + + // test copy. 
+ auto spt2 = spt; + EXPECT_EQ(spt2.pieces_size(), spt.pieces_size()); + for (int i = 0; i < spt.pieces_size(); ++i) { + EXPECT_EQ(spt2.pieces(i).surface(), spt.pieces(i).surface()); + EXPECT_EQ(spt2.pieces(i).piece(), spt.pieces(i).piece()); + EXPECT_EQ(spt2.pieces(i).id(), spt.pieces(i).id()); + EXPECT_EQ(spt2.pieces(i).begin(), spt.pieces(i).begin()); + EXPECT_EQ(spt2.pieces(i).end(), spt.pieces(i).end()); + } +} + +TEST(SentencePieceProcessorTest, ImmutableNBestSentencePieceTextTest) { + ImmutableNBestSentencePieceText spt; + auto *v = spt.mutable_proto(); + for (int i = 0; i < 10; ++i) { + auto *p = v->add_nbests(); + p->set_text(absl::StrCat("text_", i)); + p->set_score(2.0 * i); + } + + EXPECT_EQ(v->nbests_size(), spt.nbests_size()); + for (int i = 0; i < v->nbests_size(); ++i) { + EXPECT_EQ(v->nbests(i).text(), spt.nbests(i).text()); + EXPECT_EQ(v->nbests(i).score(), spt.nbests(i).score()); + } + EXPECT_EQ(v->SerializeAsString(), spt.SerializeAsString()); + + // test copy. + auto spt2 = spt; + EXPECT_EQ(spt2.nbests_size(), spt.nbests_size()); + EXPECT_EQ(spt2.SerializeAsString(), spt.SerializeAsString()); +} + } // namespace sentencepiece diff --git a/src/unigram_model.cc b/src/unigram_model.cc index ea48912..d9f1ce9 100644 --- a/src/unigram_model.cc +++ b/src/unigram_model.cc @@ -198,16 +198,17 @@ Lattice::LatticePathWithScore Lattice::Viterbi() { return retval; } -std::vector Lattice::ForwardAlgorithm(float theta) const { +std::vector Lattice::ForwardAlgorithm(float inv_theta) const { const int len = size(); std::vector alpha(node_allocator_.size(), 0.0); for (int pos = 0; pos <= len; ++pos) { for (Node *rnode : begin_nodes_[pos]) { for (Node *lnode : end_nodes_[pos]) { - alpha[rnode->node_id] = LogSumExp( - alpha[rnode->node_id], theta * lnode->score + alpha[lnode->node_id], - lnode == end_nodes_[pos][0]); + alpha[rnode->node_id] = + LogSumExp(alpha[rnode->node_id], + inv_theta * lnode->score + alpha[lnode->node_id], + lnode == end_nodes_[pos][0]); } } } 
@@ -215,7 +216,7 @@ std::vector Lattice::ForwardAlgorithm(float theta) const { return alpha; } -std::vector Lattice::BackwardAlgorithm(float theta) const { +std::vector Lattice::BackwardAlgorithm(float inv_theta) const { const int len = size(); std::vector beta(node_allocator_.size(), 0.0); @@ -260,17 +261,16 @@ float Lattice::PopulateMarginal(float freq, return freq * Z; } -float Lattice::CalculateEntropy(float theta) const { +float Lattice::CalculateEntropy(float inv_theta) const { const int len = size(); // alpha[node_id] is the marginal prob of sequence up to start of node // H is entropy of sequence // the index of alpha/H is Node::node_id. - std::vector alpha(node_allocator_.size(), 0.0); std::vector H(node_allocator_.size(), 0.0); // Populate the forward marginals to get the normalising constant - alpha = ForwardAlgorithm(theta); + const auto alpha = ForwardAlgorithm(inv_theta); // Now populate the forward entropies for (int pos = 0; pos <= len; ++pos) { @@ -280,7 +280,7 @@ float Lattice::CalculateEntropy(float theta) const { // We have to normalise p(lnode) by the marginal contribution it makes const float lnode_transition_prob = - ((theta * lnode->score) + alpha[lnode->node_id] - + ((inv_theta * lnode->score) + alpha[lnode->node_id] - alpha[rnode->node_id]); H[rnode->node_id] += std::exp(lnode_transition_prob) * (H[lnode->node_id] + lnode_transition_prob); @@ -345,7 +345,7 @@ Hypothesis *CloneHypAndDependents( std::vector Lattice::NBest(size_t nbest_size, bool sample, - float theta) { + float inv_theta) { if (nbest_size < 1) { LOG(WARNING) << "nbest_size >= 1. Returns empty result."; return {}; @@ -391,7 +391,7 @@ std::vector Lattice::NBest(size_t nbest_size, if (sample) { // Run forwards algorithm to get normalising constants - alpha = ForwardAlgorithm(theta); + alpha = ForwardAlgorithm(inv_theta); // f(eos) = Gumbel(0), as it is the perturbed score of the entire lattice. 
eos->fx = Gumbel(); } else { @@ -432,7 +432,8 @@ std::vector Lattice::NBest(size_t nbest_size, for (int i = 0; i < end_nodes(node->pos).size(); i++) { Node *lnode = end_nodes(node->pos)[i]; // Calculate backwards transition score - probs[i] = top->gx + alpha[lnode->node_id] + (theta * lnode->score) - Z; + probs[i] = + top->gx + alpha[lnode->node_id] + (inv_theta * lnode->score) - Z; perturbed_probs[i] = probs[i] + Gumbel(); if (perturbed_probs[i] > max_score) { max_score = perturbed_probs[i]; @@ -508,13 +509,13 @@ std::vector Lattice::NBest(size_t nbest_size, return results; } -std::vector Lattice::Sample(float theta) { +std::vector Lattice::Sample(float inv_theta) { const int len = size(); if (len == 0) return {}; std::vector alpha(node_allocator_.size(), 0.0); - alpha = ForwardAlgorithm(theta); + alpha = ForwardAlgorithm(inv_theta); auto *mt = random::GetRandomGenerator(); @@ -526,8 +527,8 @@ std::vector Lattice::Sample(float theta) { while (true) { probs.clear(); for (const Node *lnode : end_nodes_[node->pos]) { - probs.push_back(std::exp(static_cast(alpha[lnode->node_id] + - theta * lnode->score - Z))); + probs.push_back(std::exp(static_cast( + alpha[lnode->node_id] + inv_theta * lnode->score - Z))); } std::discrete_distribution dist(probs.begin(), probs.end()); node = end_nodes_[node->pos][dist(*mt)]; @@ -721,7 +722,7 @@ NBestEncodeResult Model::NBestEncode(absl::string_view normalized, } EncodeResult Model::SampleEncode(absl::string_view normalized, - float theta) const { + float inv_theta) const { if (!status().ok() || normalized.empty()) { return {}; } @@ -731,7 +732,7 @@ EncodeResult Model::SampleEncode(absl::string_view normalized, PopulateNodes(&lattice); EncodeResult results; - for (const auto *node : lattice.Sample(theta)) { + for (const auto *node : lattice.Sample(inv_theta)) { results.emplace_back(node->piece, node->id); } @@ -739,7 +740,7 @@ EncodeResult Model::SampleEncode(absl::string_view normalized, } NBestEncodeResult 
Model::SampleEncodeAndScore(absl::string_view normalized, - float theta, int samples, + float inv_theta, int samples, bool wor, bool include_best) const { if (!status().ok() || normalized.empty()) { @@ -750,16 +751,16 @@ NBestEncodeResult Model::SampleEncodeAndScore(absl::string_view normalized, lattice.SetSentence(normalized); PopulateNodes(&lattice); - std::vector alpha = lattice.ForwardAlgorithm(theta); - float marginal = alpha[lattice.eos_node()->node_id]; + const std::vector alpha = lattice.ForwardAlgorithm(inv_theta); + const float marginal = alpha[lattice.eos_node()->node_id]; if (include_best) { if (!wor) { - LOG(FATAL) << "include_best not supported for wor false"; + LOG(ERROR) << "include_best not supported for wor false"; + return {}; } EncodeResult result; - Lattice::LatticePathWithScore best_path = lattice.Viterbi(); - + const auto best_path = lattice.Viterbi(); for (const auto *node : best_path.first) { result.emplace_back(node->piece, node->id); } @@ -770,8 +771,7 @@ NBestEncodeResult Model::SampleEncodeAndScore(absl::string_view normalized, if (wor) { // Draw k+1 samples as we need perturbed score of k+1th element - std::vector nbest_samples = - lattice.NBest(samples + 1, true, theta); + auto nbest_samples = lattice.NBest(samples + 1, true, inv_theta); if (include_best) { std::vector> nbest_paths( @@ -780,14 +780,13 @@ NBestEncodeResult Model::SampleEncodeAndScore(absl::string_view normalized, nbest_paths[i] = nbest_samples[i].first; } // Remove the best result from the samples if necessary - Lattice::LatticePathWithScore best_path = lattice.Viterbi(); + const auto best_path = lattice.Viterbi(); const int index_of_best = (std::find(nbest_paths.begin(), nbest_paths.end(), best_path.first) - nbest_paths.begin()); if (index_of_best != nbest_samples.size()) { - LOG(INFO) << "removing best path from samples"; nbest_samples.erase(nbest_samples.begin() + index_of_best); } else { nbest_samples.pop_back(); @@ -803,7 +802,7 @@ NBestEncodeResult 
Model::SampleEncodeAndScore(absl::string_view normalized, float score = 0.0; for (const auto *node : nbest.first) { - score += (theta * node->score); + score += (inv_theta * node->score); result.emplace_back(node->piece, node->id); } @@ -814,8 +813,8 @@ NBestEncodeResult Model::SampleEncodeAndScore(absl::string_view normalized, for (auto &it : results) { // Only modify non best sample inclusion probabilities. if (it.second != 0.0) { - double x = it.second - kappa; - double y = std::exp(x); + const double x = it.second - kappa; + const double y = std::exp(x); double inclusion_prob; if (x <= -10) { // Series expansion of the log Gumbel survival function up to eps. @@ -835,10 +834,10 @@ NBestEncodeResult Model::SampleEncodeAndScore(absl::string_view normalized, float score = 0.0; EncodeResult result; - std::vector sample = lattice.Sample(theta); + const std::vector sample = lattice.Sample(inv_theta); for (const auto *node : sample) { result.emplace_back(node->piece, node->id); - score += (theta * node->score); + score += (inv_theta * node->score); } results.emplace_back(result, score - marginal); } @@ -847,12 +846,13 @@ NBestEncodeResult Model::SampleEncodeAndScore(absl::string_view normalized, return results; } -float Model::CalculateEntropy(absl::string_view normalized, float theta) const { +float Model::CalculateEntropy(absl::string_view normalized, + float inv_theta) const { Lattice lattice; lattice.SetSentence(normalized); PopulateNodes(&lattice); - return lattice.CalculateEntropy(theta); + return lattice.CalculateEntropy(inv_theta); } bool Model::VerifyOutputsEquivalent(absl::string_view expected, diff --git a/src/unigram_model.h b/src/unigram_model.h index 448e489..aa4f28f 100644 --- a/src/unigram_model.h +++ b/src/unigram_model.h @@ -173,6 +173,18 @@ class Model : public ModelInterface { bool VerifyOutputsEquivalent(absl::string_view expected, absl::string_view actual) const override; + enum EncoderVersion { + kOptimized, // The optimized encoder. 
+ kOriginal // The original encoder. + }; + + void SetEncoderVersion(EncoderVersion encoder_version) { + encoder_version_ = encoder_version; + } + + // Returns the current encoder version in use. + EncoderVersion GetEncoderVersion() const { return encoder_version_; } + protected: // Builds a Trie index. void BuildTrie(std::vector> *pieces); @@ -195,6 +207,9 @@ class Model : public ModelInterface { // Maximum size of the return value of Trie, which corresponds // to the maximum size of shared common prefix in the sentence pieces. int trie_results_size_; + + // encoder version. + EncoderVersion encoder_version_ = kOptimized; }; } // namespace unigram diff --git a/src/unigram_model_test.cc b/src/unigram_model_test.cc index 8049d20..221bac2 100644 --- a/src/unigram_model_test.cc +++ b/src/unigram_model_test.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License.! +#include "unigram_model.h" + #include #include #include @@ -22,7 +24,6 @@ #include "testharness.h" #include "third_party/absl/strings/str_cat.h" #include "third_party/absl/strings/str_join.h" -#include "unigram_model.h" #include "util.h" namespace sentencepiece { @@ -249,14 +250,14 @@ TEST(LatticeTest, NBestSampleTest) { // Calculate expected probabilities of each path // Note that sampling without replacement affects the expected frequencies! 
- const std::vector kTheta = {0.0, 0.01, 0.5, 0.7, 1.0}; - for (const auto theta : kTheta) { + const std::vector kInv_Theta = {0.0, 0.01, 0.5, 0.7, 1.0}; + for (const auto inv_theta : kInv_Theta) { std::vector strings = {"ABC", "AB C", "A BC", "A B C"}; std::map probs; - probs["ABC"] = std::exp(theta * 1.0); - probs["AB C"] = std::exp(theta * (0.2 + 0.1)); - probs["A BC"] = std::exp(theta * (0.0 + 0.5)); - probs["A B C"] = std::exp(theta * (0.0 + 0.0 + 0.1)); + probs["ABC"] = std::exp(inv_theta * 1.0); + probs["AB C"] = std::exp(inv_theta * (0.2 + 0.1)); + probs["A BC"] = std::exp(inv_theta * (0.0 + 0.5)); + probs["A B C"] = std::exp(inv_theta * (0.0 + 0.0 + 0.1)); for (const auto &it : strings) { EXPECT_EQ(1, probs.count(it)); @@ -298,7 +299,7 @@ TEST(LatticeTest, NBestSampleTest) { for (const auto num_samples : kNumSamples) { std::map counts; for (int i = 0; i < kTrials; i++) { - auto nbests = lattice.NBest(num_samples, true, theta); + auto nbests = lattice.NBest(num_samples, true, inv_theta); for (const auto &nbest : nbests) { counts[GetTokenized(nbest.first)]++; } @@ -329,14 +330,14 @@ TEST(LatticeTest, CalculateEntropyTest) { InsertWithScore(&lattice, 0, 3, 1.0); // ABC // Calculate expected probabilities of each path - const std::vector kTheta = {0.0, 0.01, 0.5, 0.7, 1.0}; - for (const auto theta : kTheta) { + const std::vector kInv_Theta = {0.0, 0.01, 0.5, 0.7, 1.0}; + for (const auto inv_theta : kInv_Theta) { std::vector strings = {"ABC", "AB C", "A BC", "A B C"}; std::map probs; - probs["ABC"] = std::exp(theta * 1.0); - probs["AB C"] = std::exp(theta * (0.2 + 0.1)); - probs["A BC"] = std::exp(theta * (0.0 + 0.5)); - probs["A B C"] = std::exp(theta * (0.0 + 0.0 + 0.1)); + probs["ABC"] = std::exp(inv_theta * 1.0); + probs["AB C"] = std::exp(inv_theta * (0.2 + 0.1)); + probs["A BC"] = std::exp(inv_theta * (0.0 + 0.5)); + probs["A B C"] = std::exp(inv_theta * (0.0 + 0.0 + 0.1)); double Z = 0.0; for (const auto &it : probs) Z += it.second; @@ -349,7 +350,7 @@ 
TEST(LatticeTest, CalculateEntropyTest) { for (const auto &it : probs) { entropy += (it.second * std::log(it.second)); } - EXPECT_NEAR(-entropy, lattice.CalculateEntropy(theta), 0.02); + EXPECT_NEAR(-entropy, lattice.CalculateEntropy(inv_theta), 0.02); } } @@ -364,9 +365,9 @@ TEST(LatticeTest, ForwardAlgorithmTest) { InsertWithScore(&lattice, 1, 2, 0.5); // BC InsertWithScore(&lattice, 0, 3, 1.0); // ABC - const std::vector kTheta = {0.0, 0.01, 0.5, 0.7, 1.0}; - for (const auto theta : kTheta) { - std::vector alpha = lattice.ForwardAlgorithm(theta); + const std::vector kInv_Theta = {0.0, 0.01, 0.5, 0.7, 1.0}; + for (const auto inv_theta : kInv_Theta) { + std::vector alpha = lattice.ForwardAlgorithm(inv_theta); EXPECT_EQ(alpha.size(), 8); // 6 nodes, plus BOS, EOS // only alpha[C], alpha[EOS] have non-zero alpha for (int i : {0, 1, 2, 3}) { @@ -374,14 +375,15 @@ TEST(LatticeTest, ForwardAlgorithmTest) { if (i < 2) { EXPECT_EQ(alpha[node->node_id], 0.0); } else if (i == 2) { - float Z = - std::log(std::exp(theta * (0.0 + 0.0)) + std::exp(theta * 0.2)); + float Z = std::log(std::exp(inv_theta * (0.0 + 0.0)) + + std::exp(inv_theta * 0.2)); EXPECT_EQ(alpha[node->node_id], Z); } else if (i == 3) { - float Z = std::log(std::exp(theta * (0.0 + 0.0 + 0.1)) + // A + B + C - std::exp(theta * (0.2 + 0.1)) + // AB + C - std::exp(theta * (0.0 + 0.5)) + // A + BC - std::exp(theta * 1.0)); // ABC + float Z = + std::log(std::exp(inv_theta * (0.0 + 0.0 + 0.1)) + // A + B + C + std::exp(inv_theta * (0.2 + 0.1)) + // AB + C + std::exp(inv_theta * (0.0 + 0.5)) + // A + BC + std::exp(inv_theta * 1.0)); // ABC EXPECT_EQ(Z, alpha[node->node_id]); } } @@ -435,14 +437,14 @@ TEST(LatticeTest, SampleTest) { InsertWithScoreAndId(&lattice, 1, 2, 1.7, 4); // BC InsertWithScoreAndId(&lattice, 0, 3, 1.8, 5); // ABC - const std::vector kTheta = {0.0, 0.01, 0.5, 0.7, 1.0}; - for (int i = 0; i < kTheta.size(); ++i) { + const std::vector kInv_Theta = {0.0, 0.01, 0.5, 0.7, 1.0}; + for (int i = 0; i < 
kInv_Theta.size(); ++i) { std::map probs; // Expands all paths in the lattice. - probs["A B C"] = exp(kTheta[i] * (1.0 + 1.2 + 1.5)); // A B C - probs["AB C"] = exp(kTheta[i] * (1.6 + 1.5)); // AB C - probs["A BC"] = exp(kTheta[i] * (1.0 + 1.7)); // A BC - probs["ABC"] = exp(kTheta[i] * 1.8); // ABC + probs["A B C"] = exp(kInv_Theta[i] * (1.0 + 1.2 + 1.5)); // A B C + probs["AB C"] = exp(kInv_Theta[i] * (1.6 + 1.5)); // AB C + probs["A BC"] = exp(kInv_Theta[i] * (1.0 + 1.7)); // A BC + probs["ABC"] = exp(kInv_Theta[i] * 1.8); // ABC // Computes expected probabilities. double Z = 0.0; @@ -453,7 +455,7 @@ TEST(LatticeTest, SampleTest) { constexpr int kTrial = 100000; std::map freq; for (int n = 0; n < kTrial; ++n) { - freq[GetTokenized(lattice.Sample(kTheta[i]))]++; + freq[GetTokenized(lattice.Sample(kInv_Theta[i]))]++; } EXPECT_EQ(probs.size(), freq.size()); @@ -480,18 +482,18 @@ ModelProto MakeBaseModelProto() { } // Returns model protos in parameterized tests. -const std::vector &GetEncoderVersions() { - static const std::vector &v = - *new std::vector{EncoderVersion::kOptimized, - EncoderVersion::kOriginal}; +const std::vector &GetEncoderVersions() { + static const std::vector &v = + *new std::vector{Model::kOptimized, + Model::kOriginal}; return v; } -class UnigramModelTest : public test::TestWithParam { +class UnigramModelTest : public test::TestWithParam { protected: void SetUp() override { encoder_version_ = GetParam(); } void TearDown() override {} - EncoderVersion encoder_version_; + Model::EncoderVersion encoder_version_; }; void AddPiece(ModelProto *model_proto, const std::string &piece, @@ -530,15 +532,15 @@ TEST(UnigramModelTest, SampleEncodeAndScoreTest) { lattice.SetSentence("ABC"); model.PopulateNodes(&lattice); - std::vector kTheta = {0.0, 1.0}; + std::vector kInv_Theta = {0.0, 1.0}; - for (const auto theta : kTheta) { + for (const auto inv_theta : kInv_Theta) { std::vector strings = {"ABC", "AB C", "A BC", "A B C"}; std::map probs; - probs["ABC"] = 
std::exp(theta * 1.0); - probs["AB C"] = std::exp(theta * (0.2 + 0.1)); - probs["A BC"] = std::exp(theta * (0.0 + 0.5)); - probs["A B C"] = std::exp(theta * (0.0 + 0.0 + 0.1)); + probs["ABC"] = std::exp(inv_theta * 1.0); + probs["AB C"] = std::exp(inv_theta * (0.2 + 0.1)); + probs["A BC"] = std::exp(inv_theta * (0.0 + 0.5)); + probs["A B C"] = std::exp(inv_theta * (0.0 + 0.0 + 0.1)); for (const auto &it : strings) { EXPECT_EQ(1, probs.count(it)); @@ -579,8 +581,8 @@ TEST(UnigramModelTest, SampleEncodeAndScoreTest) { std::map scores; int kTrials = 50000; for (int i = 0; i < kTrials; i++) { - NBestEncodeResult sample = - model.SampleEncodeAndScore("ABC", theta, num_samples, true, false); + NBestEncodeResult sample = model.SampleEncodeAndScore( + "ABC", inv_theta, num_samples, true, false); for (const auto &it : sample) { std::vector tokens; @@ -619,7 +621,7 @@ TEST_P(UnigramModelTest, PieceToIdTest) { AddPiece(&model_proto, "d", 0.4); Model model(model_proto); - EXPECT_TRUE(model.SetEncoderVersion(encoder_version_).ok()); + model.SetEncoderVersion(encoder_version_); EXPECT_EQ(model_proto.SerializeAsString(), model.model_proto().SerializeAsString()); @@ -677,7 +679,7 @@ TEST_P(UnigramModelTest, PopulateNodesAllUnknownsTest) { ModelProto model_proto = MakeBaseModelProto(); AddPiece(&model_proto, "x"); Model model(model_proto); - EXPECT_TRUE(model.SetEncoderVersion(encoder_version_).ok()); + model.SetEncoderVersion(encoder_version_); Lattice lattice; lattice.SetSentence("abc"); @@ -701,7 +703,7 @@ TEST_P(UnigramModelTest, PopulateNodesTest) { AddPiece(&model_proto, "bc", 0.4); // 6 Model model(model_proto); - EXPECT_TRUE(model.SetEncoderVersion(encoder_version_).ok()); + model.SetEncoderVersion(encoder_version_); Lattice lattice; lattice.SetSentence("abc"); @@ -736,7 +738,7 @@ TEST_P(UnigramModelTest, PopulateNodesWithUnusedTest) { model_proto.mutable_pieces(6)->set_type(ModelProto::SentencePiece::UNUSED); Model model(model_proto); - 
EXPECT_TRUE(model.SetEncoderVersion(encoder_version_).ok()); + model.SetEncoderVersion(encoder_version_); Lattice lattice; lattice.SetSentence("abc"); @@ -761,7 +763,7 @@ TEST_P(UnigramModelTest, ModelNBestTest) { AddPiece(&model_proto, "abc", 10.0); // 8 Model model(model_proto); - EXPECT_TRUE(model.SetEncoderVersion(encoder_version_).ok()); + model.SetEncoderVersion(encoder_version_); auto nbest = model.NBestEncode("", 10); EXPECT_EQ(1, nbest.size()); @@ -800,7 +802,7 @@ TEST_P(UnigramModelTest, EncodeTest) { ModelProto::SentencePiece::USER_DEFINED); Model model(model_proto); - EXPECT_TRUE(model.SetEncoderVersion(encoder_version_).ok()); + model.SetEncoderVersion(encoder_version_); EncodeResult result; @@ -883,7 +885,7 @@ TEST_P(UnigramModelTest, EncodeWithUnusedTest) { // No unused. { Model model(model_proto); - EXPECT_TRUE(model.SetEncoderVersion(encoder_version_).ok()); + model.SetEncoderVersion(encoder_version_); const auto result = model.Encode("abcd"); EXPECT_EQ(1, result.size()); EXPECT_EQ("abcd", result[0].first); @@ -892,7 +894,7 @@ TEST_P(UnigramModelTest, EncodeWithUnusedTest) { { model_proto.mutable_pieces(3)->set_type(ModelProto::SentencePiece::UNUSED); Model model(model_proto); - EXPECT_TRUE(model.SetEncoderVersion(encoder_version_).ok()); + model.SetEncoderVersion(encoder_version_); const auto result = model.Encode("abcd"); EXPECT_EQ(2, result.size()); EXPECT_EQ("abc", result[0].first); @@ -903,7 +905,7 @@ TEST_P(UnigramModelTest, EncodeWithUnusedTest) { model_proto.mutable_pieces(3)->set_type(ModelProto::SentencePiece::UNUSED); model_proto.mutable_pieces(5)->set_type(ModelProto::SentencePiece::UNUSED); Model model(model_proto); - EXPECT_TRUE(model.SetEncoderVersion(encoder_version_).ok()); + model.SetEncoderVersion(encoder_version_); const auto result = model.Encode("abcd"); EXPECT_EQ(2, result.size()); EXPECT_EQ("abc", result[0].first); @@ -917,7 +919,7 @@ TEST_P(UnigramModelTest, EncodeWithUnusedTest) { 
model_proto.mutable_pieces(4)->set_type(ModelProto::SentencePiece::UNUSED); model_proto.mutable_pieces(5)->set_type(ModelProto::SentencePiece::NORMAL); Model model(model_proto); - EXPECT_TRUE(model.SetEncoderVersion(encoder_version_).ok()); + model.SetEncoderVersion(encoder_version_); const auto result = model.Encode("abcd"); EXPECT_EQ(2, result.size()); EXPECT_EQ("ab", result[0].first); @@ -937,7 +939,7 @@ TEST_P(UnigramModelTest, VerifyOutputsEquivalent) { AddPiece(&model_proto, "c", 2.0); // 9 AddPiece(&model_proto, "d", 1.0); // 10 Model model(model_proto); - EXPECT_TRUE(model.SetEncoderVersion(encoder_version_).ok()); + model.SetEncoderVersion(encoder_version_); // Equivalent outputs. EXPECT_TRUE(model.VerifyOutputsEquivalent("", "")); EXPECT_TRUE(model.VerifyOutputsEquivalent("a b", "a b")); diff --git a/src/util.h b/src/util.h index 285676d..fb312f1 100644 --- a/src/util.h +++ b/src/util.h @@ -60,17 +60,6 @@ uint32 GetRandomGeneratorSeed(); // String utilities namespace string_util { -struct string_view_hash { - // DJB hash function. - inline size_t operator()(const absl::string_view &sp) const { - size_t hash = 5381; - for (size_t i = 0; i < sp.size(); ++i) { - hash = ((hash << 5) + hash) + sp[i]; - } - return hash; - } -}; - template inline bool lexical_cast(absl::string_view arg, Target *result) { std::stringstream ss;