Added ImmutableSentencePiece class
authorTaku Kudo <taku@google.com>
Sun, 19 Jun 2022 15:55:46 +0000 (00:55 +0900)
committerKentaro Hayashi <kenhys@xdump.org>
Mon, 21 Nov 2022 13:43:46 +0000 (13:43 +0000)
Signed-off-by: Kentaro Hayashi <kenhys@gmail.com>
Gbp-Pq: Name 0010-Added-ImmutableSentencePiece-class.patch

src/bpe_model.cc
src/model_interface.h
src/model_interface_test.cc
src/sentencepiece_processor.cc
src/sentencepiece_processor.h
src/sentencepiece_processor_test.cc
src/unigram_model.cc
src/unigram_model.h
src/unigram_model_test.cc
src/util.h

index 22cd11567d073c97192cc7a5ae05e8aed7e90626..bc7ada13a7848f5043a6d371ad34923f527d8336 100644 (file)
@@ -12,6 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.!
 
+#include "bpe_model.h"
+
 #include <functional>
 #include <memory>
 #include <queue>
@@ -19,7 +21,6 @@
 #include <utility>
 #include <vector>
 
-#include "bpe_model.h"
 #include "freelist.h"
 #include "third_party/absl/container/flat_hash_map.h"
 #include "util.h"
@@ -71,8 +72,7 @@ std::vector<std::pair<absl::string_view, int>> Model::SampleEncode(
   // Reverse merge rules.
   // key: merged symbol, value: pair of original symbols.
   absl::flat_hash_map<absl::string_view,
-                      std::pair<absl::string_view, absl::string_view>,
-                      string_util::string_view_hash>
+                      std::pair<absl::string_view, absl::string_view>>
       rev_merge;
 
   // Pre-allocates SymbolPair for efficiency.
index 06b3a6588685f611e50a35335db78069e7ceef7c..06e924302dd892ad80c0b8f452298d2c7de9356e 100644 (file)
@@ -53,8 +53,8 @@ class ModelProto;
 // Given a normalized string, returns a sequence of sentence pieces with ids.
 class ModelInterface {
  public:
-  using PieceToIdMap = absl::flat_hash_map<absl::string_view, int,
-                                           string_util::string_view_hash>;
+  using PieceToIdMap = absl::flat_hash_map<absl::string_view, int>;
+  //                                           string_util::string_view_hash>;
 
   absl::string_view unk_piece() const;
   absl::string_view bos_piece() const;
@@ -77,19 +77,6 @@ class ModelInterface {
     return matcher_.get();
   }
 
-  // Sets the encoder version. Currently only unigram has an optimized encoder.
-  // The optimized version is always used by default if there is one, so
-  // normally users do not need to call this function. This function is provided
-  // just in case that a user want to manually choose which encoder version to
-  // use.
-  virtual util::Status SetEncoderVersion(EncoderVersion encoder_version) {
-    encoder_version_ = encoder_version;
-    return util::OkStatus();
-  }
-
-  // Returns the current encoder version in use.
-  virtual EncoderVersion GetEncoderVersion() const { return encoder_version_; }
-
   // Given a normalized string, returns a sequence of sentence pieces with ids.
   // The concatenation of pieces must be the same as `normalized`.
   virtual EncodeResult Encode(absl::string_view normalized) const = 0;
@@ -123,10 +110,9 @@ class ModelInterface {
   }
 
   // Calculates the entropy of the segmentation lattice with inverse temperature
-  // `theta`.
-  // Uses a novel dynamic program to calculate the entropy.
+  // `alpha`. Uses a novel dynamic program to calculate the entropy.
   virtual float CalculateEntropy(absl::string_view normalized,
-                                 float theta) const {
+                                 float alpha) const {
     LOG(ERROR) << "Not implemented.";
     return 0.0;
   }
@@ -256,10 +242,6 @@ class ModelInterface {
   // unknown id.
   int unk_id_ = 0;
 
-  // The encoder version. Currently it is only effective for unigram model but
-  // ignored by other models.
-  EncoderVersion encoder_version_ = EncoderVersion::kOptimized;
-
   // status.
   util::Status status_;
 };
index 69ee4e60272dec3158525adef9519a5c99b139d8..09e41d34812d9e9a011f5460d3f73c33deda3eb6 100644 (file)
@@ -12,8 +12,9 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.!
 
-#include "model_factory.h"
 #include "model_interface.h"
+
+#include "model_factory.h"
 #include "testharness.h"
 #include "third_party/absl/container/flat_hash_map.h"
 #include "util.h"
@@ -481,22 +482,6 @@ TEST(ModelInterfaceTest, PieceToByteTest) {
   EXPECT_EQ(PieceToByte("a"), -1);
 }
 
-TEST(ModelInterfaceTest, SetEncoderVersion) {
-  for (const auto type : kModelTypes) {
-    ModelProto model_proto = MakeBaseModelProto(type);
-    AddPiece(&model_proto, "a");
-    AddPiece(&model_proto, "b");
-    auto model = ModelFactory::Create(model_proto);
-
-    // Verify the default encoder version.
-    EXPECT_EQ(EncoderVersion::kOptimized, model->GetEncoderVersion());
-
-    // Set the encoder version to original and verify.
-    EXPECT_TRUE(model->SetEncoderVersion(EncoderVersion::kOriginal).ok());
-    EXPECT_EQ(EncoderVersion::kOriginal, model->GetEncoderVersion());
-  }
-}
-
 TEST(ModelInterfaceTest, VerifyOutputsEquivalent) {
   for (const auto type : kModelTypes) {
     ModelProto model_proto = MakeBaseModelProto(type);
index 331fc904d6c9e68a12e5714fd0e916a1314fd9d2..a6f53953a3de9ca27f002c9c701e2624e0d902d8 100644 (file)
@@ -56,6 +56,112 @@ std::vector<absl::string_view> ToPieceArray(const std::vector<std::string> &v) {
 }
 }  // namespace
 
+ImmutableSentencePieceText::ImmutableSentencePieceText() {}
+ImmutableSentencePieceText::~ImmutableSentencePieceText() {}
+
+ImmutableSentencePieceText::ImmutableSentencePieceText(
+    const SentencePieceText &spt)
+    : spt_(&spt) {}
+
+ImmutableSentencePieceText::ImmutableSentencePiece::ImmutableSentencePiece(
+    const SentencePieceText_SentencePiece &sp)
+    : sp_(&sp) {}
+
+absl::string_view ImmutableSentencePieceText::ImmutableSentencePiece::piece()
+    const {
+  return sp_->piece();
+}
+
+absl::string_view ImmutableSentencePieceText::ImmutableSentencePiece::surface()
+    const {
+  return sp_->surface();
+}
+
+uint32_t ImmutableSentencePieceText::ImmutableSentencePiece::id() const {
+  return sp_->id();
+}
+
+uint32_t ImmutableSentencePieceText::ImmutableSentencePiece::begin() const {
+  return sp_->begin();
+}
+
+uint32_t ImmutableSentencePieceText::ImmutableSentencePiece::end() const {
+  return sp_->end();
+}
+
+std::vector<ImmutableSentencePieceText::ImmutableSentencePiece>
+ImmutableSentencePieceText::pieces() const {
+  std::vector<ImmutableSentencePieceText::ImmutableSentencePiece> pieces;
+  if (spt_ == nullptr) return pieces;
+  pieces.reserve(spt_->pieces_size());
+  for (int i = 0; i < spt_->pieces_size(); ++i)
+    pieces[i] = ImmutableSentencePiece(spt_->pieces(i));
+  return pieces;
+}
+
+size_t ImmutableSentencePieceText::pieces_size() const {
+  return spt_ ? spt_->pieces_size() : 0;
+}
+
+ImmutableSentencePieceText::ImmutableSentencePiece
+ImmutableSentencePieceText::pieces(int index) const {
+  return ImmutableSentencePieceText::ImmutableSentencePiece(
+      spt_->pieces(index));
+}
+
+absl::string_view ImmutableSentencePieceText::text() const {
+  return spt_ ? spt_->text() : "";
+}
+
+float ImmutableSentencePieceText::score() const {
+  return spt_ ? spt_->score() : 0.0;
+}
+
+SentencePieceText *ImmutableSentencePieceText::mutable_proto() {
+  if (rep_ == nullptr) {
+    rep_ = std::make_shared<SentencePieceText>();
+    spt_ = rep_.get();
+  }
+  return rep_.get();
+}
+
+std::string ImmutableSentencePieceText::SerializeAsString() const {
+  return spt_ ? spt_->SerializeAsString() : "";
+}
+
+ImmutableNBestSentencePieceText::ImmutableNBestSentencePieceText() {}
+ImmutableNBestSentencePieceText::~ImmutableNBestSentencePieceText() {}
+
+size_t ImmutableNBestSentencePieceText::nbests_size() const {
+  return rep_ ? rep_->nbests_size() : 0;
+}
+
+ImmutableSentencePieceText ImmutableNBestSentencePieceText::nbests(
+    int index) const {
+  return ImmutableSentencePieceText(rep_->nbests(index));
+}
+
+std::vector<ImmutableSentencePieceText>
+ImmutableNBestSentencePieceText::nbests() const {
+  std::vector<ImmutableSentencePieceText> nbests;
+  if (rep_ == nullptr) return nbests;
+  nbests.reserve(rep_->nbests_size());
+  for (int i = 0; i < rep_->nbests_size(); ++i)
+    nbests[i] = ImmutableSentencePieceText(rep_->nbests(i));
+  return nbests;
+}
+
+NBestSentencePieceText *ImmutableNBestSentencePieceText::mutable_proto() {
+  if (rep_ == nullptr) {
+    rep_ = std::make_shared<NBestSentencePieceText>();
+  }
+  return rep_.get();
+}
+
+std::string ImmutableNBestSentencePieceText::SerializeAsString() const {
+  return rep_ ? rep_->SerializeAsString() : "";
+}
+
 SentencePieceProcessor::SentencePieceProcessor() {}
 SentencePieceProcessor::~SentencePieceProcessor() {}
 
@@ -124,15 +230,6 @@ util::Status SentencePieceProcessor::Load(
   return util::OkStatus();
 }
 
-util::Status SentencePieceProcessor::SetEncoderVersion(
-    EncoderVersion encoder_version) {
-  return model_->SetEncoderVersion(encoder_version);
-}
-
-EncoderVersion SentencePieceProcessor::GetEncoderVersion() const {
-  return model_->GetEncoderVersion();
-}
-
 util::Status SentencePieceProcessor::SetEncodeExtraOptions(
     absl::string_view extra_options) {
   return ParseExtraOptions(extra_options, &encode_extra_options_);
@@ -348,14 +445,14 @@ util::Status SentencePieceProcessor::SampleEncode(absl::string_view input,
 }
 
 util::Status SentencePieceProcessor::SampleEncodeAndScore(
-    absl::string_view input, int num_samples, float theta, bool wor,
+    absl::string_view input, int num_samples, float alpha, bool wor,
     bool include_best,
     std::vector<std::pair<std::vector<std::string>, float>> *pieces) const {
   CHECK_OR_RETURN_STATUS_STL(pieces);
 
   NBestSentencePieceText spt;
   RETURN_IF_ERROR(
-      SampleEncodeAndScore(input, num_samples, theta, wor, include_best, &spt));
+      SampleEncodeAndScore(input, num_samples, alpha, wor, include_best, &spt));
 
   pieces->clear();
   pieces->reserve(spt.nbests_size());
@@ -373,14 +470,14 @@ util::Status SentencePieceProcessor::SampleEncodeAndScore(
 }
 
 util::Status SentencePieceProcessor::SampleEncodeAndScore(
-    absl::string_view input, int num_samples, float theta, bool wor,
+    absl::string_view input, int num_samples, float alpha, bool wor,
     bool include_best,
     std::vector<std::pair<std::vector<int>, float>> *ids) const {
   CHECK_OR_RETURN_STATUS_STL(ids);
 
   NBestSentencePieceText spt;
   RETURN_IF_ERROR(
-      SampleEncodeAndScore(input, num_samples, theta, wor, include_best, &spt));
+      SampleEncodeAndScore(input, num_samples, alpha, wor, include_best, &spt));
 
   ids->clear();
   ids->reserve(spt.nbests_size());
@@ -568,7 +665,7 @@ util::Status SentencePieceProcessor::SampleEncode(
 }
 
 util::Status SentencePieceProcessor::SampleEncodeAndScore(
-    absl::string_view input, int samples, float theta, bool wor,
+    absl::string_view input, int samples, float alpha, bool wor,
     bool include_best, NBestSentencePieceText *samples_spt) const {
   CHECK_OR_RETURN(model_->IsSampleEncodeAndScoreAvailable())
       << "SampleEncodeAndScore is not available for the current model.";
@@ -576,7 +673,7 @@ util::Status SentencePieceProcessor::SampleEncodeAndScore(
   std::vector<size_t> norm_to_orig;
   RETURN_IF_ERROR(normalizer_->Normalize(input, &normalized, &norm_to_orig));
 
-  const auto results = model_->SampleEncodeAndScore(normalized, theta, samples,
+  const auto results = model_->SampleEncodeAndScore(normalized, alpha, samples,
                                                     wor, include_best);
   CHECK_OR_RETURN(!results.empty())
       << "SampleEncodeAndScore returns empty result.";
@@ -592,7 +689,7 @@ util::Status SentencePieceProcessor::SampleEncodeAndScore(
 }
 
 util::Status SentencePieceProcessor::CalculateEntropy(absl::string_view input,
-                                                      float theta,
+                                                      float alpha,
                                                       float *entropy) const {
   CHECK_OR_RETURN(model_->IsCalculateEntropyAvailable())
       << "CalculateEntropy is not available for the current model.";
@@ -600,7 +697,7 @@ util::Status SentencePieceProcessor::CalculateEntropy(absl::string_view input,
   std::vector<size_t> norm_to_orig;
   RETURN_IF_ERROR(normalizer_->Normalize(input, &normalized, &norm_to_orig));
 
-  *entropy = model_->CalculateEntropy(normalized, theta);
+  *entropy = model_->CalculateEntropy(normalized, alpha);
   return util::OkStatus();
 }
 
@@ -770,48 +867,6 @@ util::Status SentencePieceProcessor::Decode(const std::vector<int> &ids,
   return Decode(pieces, spt);
 }
 
-std::string SentencePieceProcessor::EncodeAsSerializedProto(
-    absl::string_view input) const {
-  SentencePieceText spt;
-  if (!Encode(input, &spt).ok()) return "";
-  return spt.SerializeAsString();
-}
-
-std::string SentencePieceProcessor::SampleEncodeAsSerializedProto(
-    absl::string_view input, int nbest_size, float alpha) const {
-  SentencePieceText spt;
-  if (!SampleEncode(input, nbest_size, alpha, &spt).ok()) return "";
-  return spt.SerializeAsString();
-}
-
-std::string SentencePieceProcessor::NBestEncodeAsSerializedProto(
-    absl::string_view input, int nbest_size) const {
-  NBestSentencePieceText spt;
-  if (!NBestEncode(input, nbest_size, &spt).ok()) return "";
-  return spt.SerializeAsString();
-}
-
-std::string SentencePieceProcessor::DecodePiecesAsSerializedProto(
-    const std::vector<std::string> &pieces) const {
-  SentencePieceText spt;
-  if (!Decode(pieces, &spt).ok()) return "";
-  return spt.SerializeAsString();
-}
-
-std::string SentencePieceProcessor::DecodePiecesAsSerializedProto(
-    const std::vector<absl::string_view> &pieces) const {
-  SentencePieceText spt;
-  if (!Decode(pieces, &spt).ok()) return "";
-  return spt.SerializeAsString();
-}
-
-std::string SentencePieceProcessor::DecodeIdsAsSerializedProto(
-    const std::vector<int> &ids) const {
-  SentencePieceText spt;
-  if (!Decode(ids, &spt).ok()) return "";
-  return spt.SerializeAsString();
-}
-
 #define CHECK_STATUS_OR_RETURN_DEFAULT(value)                                \
   if (!status().ok()) {                                                      \
     LOG(ERROR) << status().message() << "\nReturns default value " << value; \
index 8c72656dace7dff5024251a5d22a8299b257b1f7..51c5b3bac7ea33bbbacdcc1ed89ef70e2b4b7eb4 100644 (file)
@@ -29,11 +29,6 @@ using std::string_view;
 #endif  // SWIG
 
 namespace sentencepiece {
-
-#ifndef SWIG
-using EncodeResult = std::vector<std::pair<absl::string_view, int>>;
-#endif  // SWIG
-
 namespace util {
 
 enum class StatusCode : int {
@@ -107,17 +102,17 @@ class Status {
 //   sp.Load("//path/to/model");
 //
 //   vector<string> sps;
-//   sp.Encode("hello world.", &sps);
+//   sp.Encode("hello world.", &sps).IgnoreError();
 //
 //   vector<int> ids;
-//   sp.Encode("hello world.", &ids);
+//   sp.Encode("hello world.", &ids).IgnoreError();
 //
 //   string detok;
 //   sp.Decode(sps, &detok);
-//   CHECK_EQ("hello world.", detok);
+//   CHECK_EQ("hello world.", detok).IgnoreError();
 //
 //   sp.Decode(ids, &detok);
-//   CHECK_EQ("hello world.", detok);
+//   CHECK_EQ("hello world.", detok).IgnoreError();
 //
 //  We can also use SentencePieceText which manages the byte-offsets
 //  between user input (output) and internal sentence pieces.
@@ -144,16 +139,6 @@ namespace normalizer {
 class Normalizer;
 }  // namespace normalizer
 
-#ifndef SWIG
-// Defines the multiple versions of encoder within each model. Currently only
-// the Unigram model has an optimized encoder.
-enum class EncoderVersion {
-  kOptimized,  // The optimized encoder (default).
-  kOriginal    // The original encoder (user may choose to fall back to this
-               // just in case).
-};
-#endif
-
 #ifndef SWIGGO
 namespace util {
 // Redefine std::string for serialized_proto interface as Python's string is
@@ -161,7 +146,87 @@ namespace util {
 // with SWIG's typemap.
 using bytes = std::string;
 }  // namespace util
-#endif
+#endif  // SWIGGO
+
+class NBestSentencePieceText;
+class ModelInterface;
+class SentencePieceText;
+class SentencePieceText_SentencePiece;
+
+// Wrapper class of SentencePieceText
+// This wrapper only allows an immutable access to the proto and
+// hides the actual implementation of protobuf.
+// See sentencepiece.proto for the details of this class.
+class ImmutableSentencePieceText {
+ public:
+  ImmutableSentencePieceText();
+  virtual ~ImmutableSentencePieceText();
+
+  class ImmutableSentencePiece {
+   public:
+    ~ImmutableSentencePiece() = default;
+    absl::string_view piece() const;
+    absl::string_view surface() const;
+    uint32_t id() const;
+    uint32_t begin() const;
+    uint32_t end() const;
+
+    friend class ImmutableSentencePieceText;
+
+   private:
+    ImmutableSentencePiece() = default;
+    explicit ImmutableSentencePiece(const SentencePieceText_SentencePiece &sp);
+    const SentencePieceText_SentencePiece *sp_ = nullptr;
+  };
+
+  std::vector<ImmutableSentencePiece> pieces() const;
+  size_t pieces_size() const;
+  ImmutableSentencePiece pieces(int index) const;
+  absl::string_view text() const;
+  float score() const;
+
+  std::string SerializeAsString() const;
+
+  // Returns the actual mutable proto.
+  // Do not use this outside of SentencePieceProcessor, as
+  // it returns the raw pointer managed by the shared_ptr.
+  SentencePieceText *mutable_proto();
+
+  friend class ImmutableNBestSentencePieceText;
+  friend class SentencePieceProcessor;
+
+ private:
+  explicit ImmutableSentencePieceText(const SentencePieceText &spt);
+  const SentencePieceText *spt_ = nullptr;
+  std::shared_ptr<SentencePieceText> rep_;
+};
+
+// Wrapper class of SentencePieceText
+// This wrapper only allows an immutable access to the proto and
+// hides the actual implementation of protobuf.
+// See sentencepiece.proto for the details of this class.
+class ImmutableNBestSentencePieceText {
+ public:
+  ImmutableNBestSentencePieceText();
+  virtual ~ImmutableNBestSentencePieceText();
+
+  std::vector<ImmutableSentencePieceText> nbests() const;
+
+  size_t nbests_size() const;
+  ImmutableSentencePieceText nbests(int index) const;
+
+  std::string SerializeAsString() const;
+
+  // Returns the actual mutable proto.
+  // Do not use this outside of SentencePieceProcessor, as
+  // it returns the raw pointer managed by the shared_ptr.
+  NBestSentencePieceText *mutable_proto();
+
+  friend class SentencePieceProcessor;
+
+ private:
+  std::shared_ptr<NBestSentencePieceText> rep_;
+};
 
 class SentencePieceProcessor {
  public:
@@ -217,7 +282,7 @@ class SentencePieceProcessor {
                                       int threshold);
 
   //////////////////////////////////////////////////////////////
-  // Simple API.
+  // Simple Encode and Decode API.
   //
   // Given a UTF8 input, encodes it into a sequence of sentence pieces.
   virtual util::Status Encode(absl::string_view input,
@@ -239,18 +304,9 @@ class SentencePieceProcessor {
   virtual util::Status Decode(const std::vector<int> &ids,
                               std::string *detokenized) const;
 
-#ifndef SWIG
-  // Sets the encoder version. Normally users do not need to call this function.
-  // But they can call this fucntion just in case if they want to fall back to
-  // the original encoder.
-  virtual util::Status SetEncoderVersion(EncoderVersion encoder_version);
-
-  // Returns the current encoder version in use.
-  virtual EncoderVersion GetEncoderVersion() const;
-#endif
-
   //////////////////////////////////////////////////////////////
   // NBest API.
+  //
   // Same as Encode, but returns nbest results.
   virtual util::Status NBestEncode(
       absl::string_view input, int nbest_size,
@@ -262,24 +318,24 @@ class SentencePieceProcessor {
 
   //////////////////////////////////////////////////////////////
   // Sampling API.
+  //
   // Unigram and BPE support sampling mode.
   // - Unigram (--model_type=unigram):
-  // When `nbest_size` is positive value, approximately samples one
-  // segmentation from nbest candidates. When `nbest_size` is negative value,
-  // samples one segmentation from the hypotheses (Lattice) according to the
-  // generation probabilities using forward-filtering and backward-sampling
-  // algorithm. `alpha` is a smoothing parameter.  The best segmentation
-  // (Viterbi segmentation) is more likely sampled when setting larger
-  // alpha. When alpha is 0.0, one segmentation is uniformly sampled from the
-  // nbest or lattice.
-  // `nbest_size` and `alpha` correspond to parameters `l` and `alpha`
+  // `nbest_size`: When `nbest_size` is positive value, approximately samples
+  // one segmentation from nbest candidates. When `nbest_size` is negative
+  // value, samples one segmentation from the hypotheses (Lattice) according to
+  // the generation probabilities using forward-filtering and backward-sampling
+  // algorithm.
+  // `alpha`: Smoothing parameter (inverse temperature). The best segmentation
+  // (Viterbi segmentation) is more likely sampled when setting larger alpha.
+  // When alpha is 0.0, one segmentation is uniformly sampled from the nbest or
+  // lattice. `nbest_size` and `alpha` correspond to parameters `l` and `alpha`
   // in https://arxiv.org/abs/1804.10959  (nbest_size < 0 means l = infinity)
   //
   // - BPE (--model_type=bpe):
-  // `alpha` is the dropout probability `p` of bpe merge operations
-  // in https://arxiv.org/abs/1910.13267
-  // Nbest-based sampling is not supported so nbest_size parameter is ignored in
-  // BPE.
+  // `alpha`: The dropout probability `p` of bpe merge operations in
+  // https://arxiv.org/abs/1910.13267 Nbest-based sampling is not supported so
+  // nbest_size parameter is ignored in BPE.
   virtual util::Status SampleEncode(absl::string_view input, int nbest_size,
                                     float alpha,
                                     std::vector<std::string> *pieces) const;
@@ -290,74 +346,104 @@ class SentencePieceProcessor {
 
   //////////////////////////////////////////////////////////////
   // SampleEncodeAndScore API.
-  // Similar to SampleEncode, but returns samples results.
+  //
+  // Sample `samples` many tokenisations from the segmentation lattice.
+  // These methods are only available in model_type=unigram.
+  //
+  // `alpha`: smoothing parameter (inverse temperature). The same as `alpha` in
+  // `Sample` method.
+  // 'wor`: If `wor` is true, the samples are taken without replacement, and the
+  // scores are the inclusion probabilities of the elements in the sample;
+  // otherwise the samples are taken with replacement and the scores are the
+  // log-probs of sample elements
+  // `include_best`: If `include_best` is true, the best tokenisation is always
+  // included in the sample, and the remaining elements are sampled excluding
+  // the best.
   virtual util::Status SampleEncodeAndScore(
-      absl::string_view input, int num_samples, float theta, bool wor,
+      absl::string_view input, int num_samples, float alpha, bool wor,
       bool include_best,
       std::vector<std::pair<std::vector<std::string>, float>> *pieces) const;
 
   // Same as above, but returns a sequence of ids.
   virtual util::Status SampleEncodeAndScore(
-      absl::string_view input, int num_samples, float theta, bool wor,
+      absl::string_view input, int num_samples, float alpha, bool wor,
       bool include_best,
       std::vector<std::pair<std::vector<int>, float>> *ids) const;
 
+  //////////////////////////////////////////////////////////////
+  // Entropy API.
+  //
+  // This only available in model_type=unigram.
+  // Calculate entropy of possible tokenisations
+  virtual util::Status CalculateEntropy(absl::string_view input, float alpha,
+                                        float *entropy) const;
+
   //////////////////////////////////////////////////////////////
   // Advanced API returning SentencePieceText, which manages
   // utf8-byte alignments between user-input/detokenized text
   // and internal sentencepiece sequence.
   //
   // Given a UTF8 input, encodes it into SentencePieceText.
+  //
+  // When using these APIs, sentencepiece.pb.h header files must be included.
+  // We can also use ImutableSentencePieceText as follows.
+  //
+  // ImmutableSentencePieceText spt;
+  // Encode("hello", spt.mutable_proto()).IgnoreError();
+  // std::cout << spt.pieces_size() << std::endl;
   virtual util::Status Encode(absl::string_view input,
                               SentencePieceText *spt) const;
 
-  // Same as above, but returns NBestSentencePieceText.
   virtual util::Status NBestEncode(absl::string_view input, int nbest_size,
                                    NBestSentencePieceText *nbest_spt) const;
 
-  // Same as above, but samples one segmentation from the hypotheses
-  // (Lattice).
   virtual util::Status SampleEncode(absl::string_view input, int nbest_size,
                                     float alpha, SentencePieceText *spt) const;
 
-  // Samples N segmentation and returns the scores as well
   virtual util::Status SampleEncodeAndScore(
-      absl::string_view input, int samples, float theta, bool wor,
+      absl::string_view input, int samples, float alpha, bool wor,
       bool include_best, NBestSentencePieceText *samples_spt) const;
 
-  // Calculate entropy of possible tokenisations
-  virtual util::Status CalculateEntropy(absl::string_view input, float theta,
-                                        float *entropy) const;
-
-  // Given a sequence of pieces, decodes it into SentencePieceText.
-  // TODO(taku): Remove this API and use std::vector<std::string_view>
+  // DEPRECATED: Remove this API and use std::vector<std::string_view>
   virtual util::Status Decode(const std::vector<std::string> &pieces,
                               SentencePieceText *spt) const;
 
-  // Given a sequence of pieces, decodes it into SentencePieceText.
   virtual util::Status Decode(const std::vector<absl::string_view> &pieces,
                               SentencePieceText *spt) const;
 
-  // Given a sequence of ids, decodes it into SentencePieceText.
   virtual util::Status Decode(const std::vector<int> &ids,
                               SentencePieceText *spt) const;
 
-  //////////////////////////////////////////////////////////////
-  // Handy methods that return the result directly.
-  // These functions ignore internal errors.
 #ifdef SWIG
-#define DEFINE_SPP_DIRECT_FUNC_IMPL(FuncName, OutType, ...) \
-  OutType output;                                           \
-  const auto _status = FuncName(__VA_ARGS__, &output);      \
-  if (!_status.ok()) throw _status;                         \
-  return output;
+#define SPP_SWIG_CHECK_AND_THROW \
+  if (!status.ok()) throw status;
 #else
+#define SPP_SWIG_CHECK_AND_THROW \
+  if (!status.ok()) {            \
+  }
+#endif  // SWIG
+
 #define DEFINE_SPP_DIRECT_FUNC_IMPL(FuncName, OutType, ...) \
   OutType output;                                           \
-  FuncName(__VA_ARGS__, &output).IgnoreError();             \
+  const auto status = FuncName(__VA_ARGS__, &output);       \
+  SPP_SWIG_CHECK_AND_THROW;                                 \
+  return output;
+
+#define DEFINE_SPP_SERIALIZED_PROTO_IMPL(FuncName, OutType, ...)     \
+  OutType output;                                                    \
+  const auto status = FuncName(__VA_ARGS__, output.mutable_proto()); \
+  SPP_SWIG_CHECK_AND_THROW;                                          \
+  return output.SerializeAsString();
+
+#define DEFINE_SPP_IMMUTABLE_PROTO_IMPL(FuncName, OutType, ...)      \
+  OutType output;                                                    \
+  const auto status = FuncName(__VA_ARGS__, output.mutable_proto()); \
+  SPP_SWIG_CHECK_AND_THROW;                                          \
   return output;
-#endif
 
+  //////////////////////////////////////////////////////////////
+  // Handy methods that return the result directly.
+  // These functions ignore internal errors.
   virtual std::vector<std::string> EncodeAsPieces(
       absl::string_view input) const {
     DEFINE_SPP_DIRECT_FUNC_IMPL(Encode, std::vector<std::string>, input);
@@ -395,21 +481,21 @@ class SentencePieceProcessor {
 
   virtual std::vector<std::pair<std::vector<std::string>, float>>
   SampleEncodeAndScoreAsPieces(absl::string_view input, int num_samples,
-                               float theta, bool wor, bool include_best) const {
+                               float alpha, bool wor, bool include_best) const {
     using _T = std::vector<std::pair<std::vector<std::string>, float>>;
     DEFINE_SPP_DIRECT_FUNC_IMPL(SampleEncodeAndScore, _T, input, num_samples,
-                                theta, wor, include_best);
+                                alpha, wor, include_best);
   }
 
   virtual std::vector<std::pair<std::vector<int>, float>>
   SampleEncodeAndScoreAsIds(absl::string_view input, int num_samples,
-                            float theta, bool wor, bool include_best) const {
+                            float alpha, bool wor, bool include_best) const {
     using _T = std::vector<std::pair<std::vector<int>, float>>;
     DEFINE_SPP_DIRECT_FUNC_IMPL(SampleEncodeAndScore, _T, input, num_samples,
-                                theta, wor, include_best);
+                                alpha, wor, include_best);
   }
 
-  // TODO(taku): Remove this API and use std::vector<std::string_view>
+  // DEPRECATED: Remove this API and use std::vector<std::string_view>
   virtual std::string DecodePieces(
       const std::vector<std::string> &pieces) const {
     DEFINE_SPP_DIRECT_FUNC_IMPL(Decode, std::string, pieces);
@@ -424,33 +510,104 @@ class SentencePieceProcessor {
     DEFINE_SPP_DIRECT_FUNC_IMPL(Decode, std::string, ids);
   }
 
-  virtual float CalculateEntropy(absl::string_view text, float theta) const {
-    DEFINE_SPP_DIRECT_FUNC_IMPL(CalculateEntropy, float, text, theta);
+  virtual float CalculateEntropy(absl::string_view text, float alpha) const {
+    DEFINE_SPP_DIRECT_FUNC_IMPL(CalculateEntropy, float, text, alpha);
   }
 
-#undef DEFINE_SPP_DIRECT_FUNC_IMPL
-
+  //////////////////////////////////////////////////////////////
+  // SerializedProto API. (DEPRECATED). Use ImmutableProto API.
   // They are used in Python interface. Returns serialized proto.
   // In python module, we can get access to the full Proto after
   // deserialzing the returned byte sequence.
-  virtual util::bytes EncodeAsSerializedProto(absl::string_view input) const;
+  virtual util::bytes EncodeAsSerializedProto(absl::string_view input) const {
+    DEFINE_SPP_SERIALIZED_PROTO_IMPL(Encode, ImmutableSentencePieceText, input);
+  }
 
   virtual util::bytes SampleEncodeAsSerializedProto(absl::string_view input,
                                                     int nbest_size,
-                                                    float alpha) const;
+                                                    float alpha) const {
+    DEFINE_SPP_SERIALIZED_PROTO_IMPL(SampleEncode, ImmutableSentencePieceText,
+                                     input, nbest_size, alpha);
+  }
 
   virtual util::bytes NBestEncodeAsSerializedProto(absl::string_view input,
-                                                   int nbest_size) const;
+                                                   int nbest_size) const {
+    DEFINE_SPP_SERIALIZED_PROTO_IMPL(
+        NBestEncode, ImmutableNBestSentencePieceText, input, nbest_size);
+  }
+
+  virtual util::bytes SampleEncodeAndScoreAsSerializedProto(
+      absl::string_view input, int samples, float alpha, bool wor,
+      bool include_best, int nbest_size) const {
+    DEFINE_SPP_SERIALIZED_PROTO_IMPL(SampleEncodeAndScore,
+                                     ImmutableNBestSentencePieceText, input,
+                                     samples, alpha, wor, include_best);
+  }
 
   // TODO(taku): Remove this API and use std::vector<std::string_view>
   virtual util::bytes DecodePiecesAsSerializedProto(
-      const std::vector<std::string> &pieces) const;
+      const std::vector<std::string> &pieces) const {
+    DEFINE_SPP_SERIALIZED_PROTO_IMPL(Decode, ImmutableSentencePieceText,
+                                     pieces);
+  }
 
   virtual util::bytes DecodePiecesAsSerializedProto(
-      const std::vector<absl::string_view> &pieces) const;
+      const std::vector<absl::string_view> &pieces) const {
+    DEFINE_SPP_SERIALIZED_PROTO_IMPL(Decode, ImmutableSentencePieceText,
+                                     pieces);
+  }
 
   virtual util::bytes DecodeIdsAsSerializedProto(
-      const std::vector<int> &ids) const;
+      const std::vector<int> &ids) const {
+    DEFINE_SPP_SERIALIZED_PROTO_IMPL(Decode, ImmutableSentencePieceText, ids);
+  }
+
+  //////////////////////////////////////////////////////////////
+  // ImmutableProto API.
+  virtual ImmutableSentencePieceText EncodeAsImmutableProto(
+      absl::string_view input) const {
+    DEFINE_SPP_IMMUTABLE_PROTO_IMPL(Encode, ImmutableSentencePieceText, input);
+  }
+
+  virtual ImmutableSentencePieceText SampleEncodeAsImmutableProto(
+      absl::string_view input, int nbest_size, float alpha) const {
+    DEFINE_SPP_IMMUTABLE_PROTO_IMPL(SampleEncode, ImmutableSentencePieceText,
+                                    input, nbest_size, alpha);
+  }
+
+  virtual ImmutableNBestSentencePieceText NBestEncodeAsImmutableProto(
+      absl::string_view input, int nbest_size) const {
+    DEFINE_SPP_IMMUTABLE_PROTO_IMPL(
+        NBestEncode, ImmutableNBestSentencePieceText, input, nbest_size);
+  }
+
+  virtual ImmutableNBestSentencePieceText SampleEncodeAndScoreAsImmutableProto(
+      absl::string_view input, int samples, float alpha, bool wor,
+      bool include_best, int nbest_size) const {
+    DEFINE_SPP_IMMUTABLE_PROTO_IMPL(SampleEncodeAndScore,
+                                    ImmutableNBestSentencePieceText, input,
+                                    samples, alpha, wor, include_best);
+  }
+
+  // TODO(taku): Remove this API and use std::vector<std::string_view>
+  virtual ImmutableSentencePieceText DecodePiecesAsImmutableProto(
+      const std::vector<std::string> &pieces) const {
+    DEFINE_SPP_IMMUTABLE_PROTO_IMPL(Decode, ImmutableSentencePieceText, pieces);
+  }
+
+  virtual ImmutableSentencePieceText DecodePiecesAsImmutableProto(
+      const std::vector<absl::string_view> &pieces) const {
+    DEFINE_SPP_IMMUTABLE_PROTO_IMPL(Decode, ImmutableSentencePieceText, pieces);
+  }
+
+  virtual ImmutableSentencePieceText DecodeIdsAsImmutableProto(
+      const std::vector<int> &ids) const {
+    DEFINE_SPP_IMMUTABLE_PROTO_IMPL(Decode, ImmutableSentencePieceText, ids);
+  }
+
+#undef DEFINE_SPP_DIRECT_FUNC_IMPL
+#undef DEFINE_SPP_SERIALIZED_PROTO_IMPL
+#undef DEFINE_SPP_IMMUTABLE_PROTO_IMPL
 
   //////////////////////////////////////////////////////////////
   // Vocabulary management methods.
@@ -467,7 +624,8 @@ class SentencePieceProcessor {
   virtual const std::string &IdToPiece(int id) const;
 
   // Returns the score of `id`.
-  // Usually score is an emission log probability of unigram language model.
+  // Usually score is an emission log probability of unigram language
+  // model.
   virtual float GetScore(int id) const;
 
   // Returns true if `id` is unknown symbol.
@@ -506,7 +664,7 @@ class SentencePieceProcessor {
 
   // Allows injection of a normalizer instance. `normalizer` is moved.
   void SetNormalizer(std::unique_ptr<normalizer::Normalizer> &&normalizer);
-#endif
+#endif  // SWIG
 
   // Returns immutable model proto. Useful to obtain extended
   // or experimental parameters encoded in model_proto.
index d57ab5a3e4d897d3f88922e9f2f7b0d71d44be27..ed651f7c83e97d6eface33bc271618ab96b0ad9e 100644 (file)
@@ -12,6 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.!
 
+#include "sentencepiece_processor.h"
+
 #include <utility>
 
 #include "builder.h"
@@ -20,7 +22,6 @@
 #include "normalizer.h"
 #include "sentencepiece.pb.h"
 #include "sentencepiece_model.pb.h"
-#include "sentencepiece_processor.h"
 #include "sentencepiece_trainer.h"
 #include "testharness.h"
 #include "third_party/absl/container/flat_hash_map.h"
@@ -551,10 +552,9 @@ TEST(SentencepieceProcessorTest, DecodeTest) {
     int GetPieceSize() const override { return 7; }
 
     int PieceToId(absl::string_view piece) const override {
-      static absl::flat_hash_map<absl::string_view, int,
-                                 string_util::string_view_hash>
-          kMap = {{"<unk>", 0}, {"<s>", 1}, {"</s>", 2},    {WS "ABC", 3},
-                  {WS "DE", 4}, {"F", 5},   {"G" WS "H", 6}};
+      static absl::flat_hash_map<absl::string_view, int> kMap = {
+          {"<unk>", 0}, {"<s>", 1}, {"</s>", 2},    {WS "ABC", 3},
+          {WS "DE", 4}, {"F", 5},   {"G" WS "H", 6}};
       return port::FindWithDefault(kMap, piece, 0);
     }
 
@@ -719,10 +719,9 @@ TEST(SentencepieceProcessorTest, DummyPrefixDecodeTest) {
     int GetPieceSize() const override { return 7; }
 
     int PieceToId(absl::string_view piece) const override {
-      static absl::flat_hash_map<absl::string_view, int,
-                                 string_util::string_view_hash>
-          kMap = {{"<unk>", 0}, {"<s>", 1}, {"</s>", 2},     {WS "ABC", 3},
-                  {WS "DE", 4}, {"F", 5},   {"G" WS "H", 6}, {WS, 7}};
+      static absl::flat_hash_map<absl::string_view, int> kMap = {
+          {"<unk>", 0}, {"<s>", 1}, {"</s>", 2},     {WS "ABC", 3},
+          {WS "DE", 4}, {"F", 5},   {"G" WS "H", 6}, {WS, 7}};
       return port::FindWithDefault(kMap, piece, 0);
     }
 
@@ -1058,18 +1057,6 @@ TEST(SentencePieceProcessorTest, EndToEndTest) {
   EXPECT_EQ(2, sp.eos_id());
   EXPECT_EQ(-1, sp.pad_id());
 
-  {
-    // Verify the default encoder version.
-    EXPECT_EQ(EncoderVersion::kOptimized, sp.GetEncoderVersion());
-
-    // Set the encoder version to original and verify.
-    EXPECT_TRUE(sp.SetEncoderVersion(EncoderVersion::kOriginal).ok());
-    EXPECT_EQ(EncoderVersion::kOriginal, sp.GetEncoderVersion());
-
-    // Set back to the default encoder version.
-    EXPECT_TRUE(sp.SetEncoderVersion(EncoderVersion::kOptimized).ok());
-  }
-
   {
     std::vector<std::string> sps;
     const std::vector<std::string> expected_str = {WS, "ab", "c"};
@@ -1574,4 +1561,77 @@ TEST(SentencePieceProcessorTest, VocabularyTest) {
   EXPECT_FALSE(sp.IsUnused(6));
   EXPECT_FALSE(sp.IsUnused(7));
 }
+
+TEST(SentencePieceProcessorTest, ImmutableSentencePieceTextTest) {
+  ImmutableSentencePieceText spt;
+  auto *v = spt.mutable_proto();
+
+  v->set_text("hello world");
+  v->set_score(1.0);
+  for (int i = 0; i < 10; ++i) {
+    auto *p = v->add_pieces();
+    p->set_surface(absl::StrCat("surface_", i));
+    p->set_piece(absl::StrCat("surface_", i));
+    p->set_id(i);
+    p->set_begin(i + 10);
+    p->set_end(i + 20);
+  }
+
+  EXPECT_EQ(v->pieces_size(), spt.pieces_size());
+  for (int i = 0; i < spt.pieces_size(); ++i) {
+    EXPECT_EQ(v->pieces(i).surface(), spt.pieces(i).surface());
+    EXPECT_EQ(v->pieces(i).piece(), spt.pieces(i).piece());
+    EXPECT_EQ(v->pieces(i).id(), spt.pieces(i).id());
+    EXPECT_EQ(v->pieces(i).begin(), spt.pieces(i).begin());
+    EXPECT_EQ(v->pieces(i).end(), spt.pieces(i).end());
+  }
+
+  int n = 0;
+  for (auto &p : spt.pieces()) {
+    EXPECT_EQ(v->pieces(n).surface(), p.surface());
+    EXPECT_EQ(v->pieces(n).piece(), p.piece());
+    EXPECT_EQ(v->pieces(n).id(), p.id());
+    EXPECT_EQ(v->pieces(n).begin(), p.begin());
+    EXPECT_EQ(v->pieces(n).end(), p.end());
+    ++n;
+  }
+
+  EXPECT_EQ(v->text(), spt.text());
+  EXPECT_EQ(v->score(), spt.score());
+  EXPECT_EQ(v->SerializeAsString(), spt.SerializeAsString());
+
+  // test copy.
+  auto spt2 = spt;
+  EXPECT_EQ(spt2.pieces_size(), spt.pieces_size());
+  for (int i = 0; i < spt.pieces_size(); ++i) {
+    EXPECT_EQ(spt2.pieces(i).surface(), spt.pieces(i).surface());
+    EXPECT_EQ(spt2.pieces(i).piece(), spt.pieces(i).piece());
+    EXPECT_EQ(spt2.pieces(i).id(), spt.pieces(i).id());
+    EXPECT_EQ(spt2.pieces(i).begin(), spt.pieces(i).begin());
+    EXPECT_EQ(spt2.pieces(i).end(), spt.pieces(i).end());
+  }
+}
+
+TEST(SentencePieceProcessorTest, ImmutableNBestSentencePieceTextTest) {
+  ImmutableNBestSentencePieceText spt;
+  auto *v = spt.mutable_proto();
+  for (int i = 0; i < 10; ++i) {
+    auto *p = v->add_nbests();
+    p->set_text(absl::StrCat("text_", i));
+    p->set_score(2.0 * i);
+  }
+
+  EXPECT_EQ(v->nbests_size(), spt.nbests_size());
+  for (int i = 0; i < v->nbests_size(); ++i) {
+    EXPECT_EQ(v->nbests(i).text(), spt.nbests(i).text());
+    EXPECT_EQ(v->nbests(i).score(), spt.nbests(i).score());
+  }
+  EXPECT_EQ(v->SerializeAsString(), spt.SerializeAsString());
+
+  // test copy.
+  auto spt2 = spt;
+  EXPECT_EQ(spt2.nbests_size(), spt.nbests_size());
+  EXPECT_EQ(spt2.SerializeAsString(), spt.SerializeAsString());
+}
+
 }  // namespace sentencepiece
index ea4891290af576254262497eb9bbddf8a0d3afab..d9f1ce9d521fcf2196d2c624ff5776626343b03a 100644 (file)
@@ -198,16 +198,17 @@ Lattice::LatticePathWithScore Lattice::Viterbi() {
   return retval;
 }
 
-std::vector<float> Lattice::ForwardAlgorithm(float theta) const {
+std::vector<float> Lattice::ForwardAlgorithm(float inv_theta) const {
   const int len = size();
   std::vector<float> alpha(node_allocator_.size(), 0.0);
 
   for (int pos = 0; pos <= len; ++pos) {
     for (Node *rnode : begin_nodes_[pos]) {
       for (Node *lnode : end_nodes_[pos]) {
-        alpha[rnode->node_id] = LogSumExp(
-            alpha[rnode->node_id], theta * lnode->score + alpha[lnode->node_id],
-            lnode == end_nodes_[pos][0]);
+        alpha[rnode->node_id] =
+            LogSumExp(alpha[rnode->node_id],
+                      inv_theta * lnode->score + alpha[lnode->node_id],
+                      lnode == end_nodes_[pos][0]);
       }
     }
   }
@@ -215,7 +216,7 @@ std::vector<float> Lattice::ForwardAlgorithm(float theta) const {
   return alpha;
 }
 
-std::vector<float> Lattice::BackwardAlgorithm(float theta) const {
+std::vector<float> Lattice::BackwardAlgorithm(float inv_theta) const {
   const int len = size();
   std::vector<float> beta(node_allocator_.size(), 0.0);
 
@@ -260,17 +261,16 @@ float Lattice::PopulateMarginal(float freq,
   return freq * Z;
 }
 
-float Lattice::CalculateEntropy(float theta) const {
+float Lattice::CalculateEntropy(float inv_theta) const {
   const int len = size();
 
   // alpha[node_id] is the marginal prob of sequence up to start of node
   // H is entropy of sequence
   // the index of alpha/H is Node::node_id.
-  std::vector<float> alpha(node_allocator_.size(), 0.0);
   std::vector<float> H(node_allocator_.size(), 0.0);
 
   // Populate the forward marginals to get the normalising constant
-  alpha = ForwardAlgorithm(theta);
+  const auto alpha = ForwardAlgorithm(inv_theta);
 
   // Now populate the forward entropies
   for (int pos = 0; pos <= len; ++pos) {
@@ -280,7 +280,7 @@ float Lattice::CalculateEntropy(float theta) const {
 
         // We have to normalise p(lnode) by the marginal contribution it makes
         const float lnode_transition_prob =
-            ((theta * lnode->score) + alpha[lnode->node_id] -
+            ((inv_theta * lnode->score) + alpha[lnode->node_id] -
              alpha[rnode->node_id]);
         H[rnode->node_id] += std::exp(lnode_transition_prob) *
                              (H[lnode->node_id] + lnode_transition_prob);
@@ -345,7 +345,7 @@ Hypothesis *CloneHypAndDependents(
 
 std::vector<Lattice::LatticePathWithScore> Lattice::NBest(size_t nbest_size,
                                                           bool sample,
-                                                          float theta) {
+                                                          float inv_theta) {
   if (nbest_size < 1) {
     LOG(WARNING) << "nbest_size >= 1. Returns empty result.";
     return {};
@@ -391,7 +391,7 @@ std::vector<Lattice::LatticePathWithScore> Lattice::NBest(size_t nbest_size,
 
   if (sample) {
     // Run forwards algorithm to get normalising constants
-    alpha = ForwardAlgorithm(theta);
+    alpha = ForwardAlgorithm(inv_theta);
     // f(eos) = Gumbel(0), as it is the perturbed score of the entire lattice.
     eos->fx = Gumbel();
   } else {
@@ -432,7 +432,8 @@ std::vector<Lattice::LatticePathWithScore> Lattice::NBest(size_t nbest_size,
       for (int i = 0; i < end_nodes(node->pos).size(); i++) {
         Node *lnode = end_nodes(node->pos)[i];
         // Calculate backwards transition score
-        probs[i] = top->gx + alpha[lnode->node_id] + (theta * lnode->score) - Z;
+        probs[i] =
+            top->gx + alpha[lnode->node_id] + (inv_theta * lnode->score) - Z;
         perturbed_probs[i] = probs[i] + Gumbel();
         if (perturbed_probs[i] > max_score) {
           max_score = perturbed_probs[i];
@@ -508,13 +509,13 @@ std::vector<Lattice::LatticePathWithScore> Lattice::NBest(size_t nbest_size,
   return results;
 }
 
-std::vector<Lattice::Node *> Lattice::Sample(float theta) {
+std::vector<Lattice::Node *> Lattice::Sample(float inv_theta) {
   const int len = size();
   if (len == 0) return {};
 
   std::vector<float> alpha(node_allocator_.size(), 0.0);
 
-  alpha = ForwardAlgorithm(theta);
+  alpha = ForwardAlgorithm(inv_theta);
 
   auto *mt = random::GetRandomGenerator();
 
@@ -526,8 +527,8 @@ std::vector<Lattice::Node *> Lattice::Sample(float theta) {
   while (true) {
     probs.clear();
     for (const Node *lnode : end_nodes_[node->pos]) {
-      probs.push_back(std::exp(static_cast<double>(alpha[lnode->node_id] +
-                                                   theta * lnode->score - Z)));
+      probs.push_back(std::exp(static_cast<double>(
+          alpha[lnode->node_id] + inv_theta * lnode->score - Z)));
     }
     std::discrete_distribution<int> dist(probs.begin(), probs.end());
     node = end_nodes_[node->pos][dist(*mt)];
@@ -721,7 +722,7 @@ NBestEncodeResult Model::NBestEncode(absl::string_view normalized,
 }
 
 EncodeResult Model::SampleEncode(absl::string_view normalized,
-                                 float theta) const {
+                                 float inv_theta) const {
   if (!status().ok() || normalized.empty()) {
     return {};
   }
@@ -731,7 +732,7 @@ EncodeResult Model::SampleEncode(absl::string_view normalized,
   PopulateNodes(&lattice);
 
   EncodeResult results;
-  for (const auto *node : lattice.Sample(theta)) {
+  for (const auto *node : lattice.Sample(inv_theta)) {
     results.emplace_back(node->piece, node->id);
   }
 
@@ -739,7 +740,7 @@ EncodeResult Model::SampleEncode(absl::string_view normalized,
 }
 
 NBestEncodeResult Model::SampleEncodeAndScore(absl::string_view normalized,
-                                              float theta, int samples,
+                                              float inv_theta, int samples,
                                               bool wor,
                                               bool include_best) const {
   if (!status().ok() || normalized.empty()) {
@@ -750,16 +751,16 @@ NBestEncodeResult Model::SampleEncodeAndScore(absl::string_view normalized,
   lattice.SetSentence(normalized);
   PopulateNodes(&lattice);
 
-  std::vector<float> alpha = lattice.ForwardAlgorithm(theta);
-  float marginal = alpha[lattice.eos_node()->node_id];
+  const std::vector<float> alpha = lattice.ForwardAlgorithm(inv_theta);
+  const float marginal = alpha[lattice.eos_node()->node_id];
 
   if (include_best) {
     if (!wor) {
-      LOG(FATAL) << "include_best not supported for wor false";
+      LOG(ERROR) << "include_best not supported for wor false";
+      return {};
     }
     EncodeResult result;
-    Lattice::LatticePathWithScore best_path = lattice.Viterbi();
-
+    const auto best_path = lattice.Viterbi();
     for (const auto *node : best_path.first) {
       result.emplace_back(node->piece, node->id);
     }
@@ -770,8 +771,7 @@ NBestEncodeResult Model::SampleEncodeAndScore(absl::string_view normalized,
 
   if (wor) {
     // Draw k+1 samples as we need perturbed score of k+1th element
-    std::vector<Lattice::LatticePathWithScore> nbest_samples =
-        lattice.NBest(samples + 1, true, theta);
+    auto nbest_samples = lattice.NBest(samples + 1, true, inv_theta);
 
     if (include_best) {
       std::vector<std::vector<Lattice::Node *>> nbest_paths(
@@ -780,14 +780,13 @@ NBestEncodeResult Model::SampleEncodeAndScore(absl::string_view normalized,
         nbest_paths[i] = nbest_samples[i].first;
       }
       // Remove the best result from the samples if necessary
-      Lattice::LatticePathWithScore best_path = lattice.Viterbi();
+      const auto best_path = lattice.Viterbi();
 
       const int index_of_best =
           (std::find(nbest_paths.begin(), nbest_paths.end(), best_path.first) -
            nbest_paths.begin());
 
       if (index_of_best != nbest_samples.size()) {
-        LOG(INFO) << "removing best path from samples";
         nbest_samples.erase(nbest_samples.begin() + index_of_best);
       } else {
         nbest_samples.pop_back();
@@ -803,7 +802,7 @@ NBestEncodeResult Model::SampleEncodeAndScore(absl::string_view normalized,
       float score = 0.0;
 
       for (const auto *node : nbest.first) {
-        score += (theta * node->score);
+        score += (inv_theta * node->score);
         result.emplace_back(node->piece, node->id);
       }
 
@@ -814,8 +813,8 @@ NBestEncodeResult Model::SampleEncodeAndScore(absl::string_view normalized,
     for (auto &it : results) {
       // Only modify non best sample inclusion probabilities.
       if (it.second != 0.0) {
-        double x = it.second - kappa;
-        double y = std::exp(x);
+        const double x = it.second - kappa;
+        const double y = std::exp(x);
         double inclusion_prob;
         if (x <= -10) {
           // Series expansion of the log Gumbel survival function up to eps.
@@ -835,10 +834,10 @@ NBestEncodeResult Model::SampleEncodeAndScore(absl::string_view normalized,
 
       float score = 0.0;
       EncodeResult result;
-      std::vector<Lattice::Node *> sample = lattice.Sample(theta);
+      const std::vector<Lattice::Node *> sample = lattice.Sample(inv_theta);
       for (const auto *node : sample) {
         result.emplace_back(node->piece, node->id);
-        score += (theta * node->score);
+        score += (inv_theta * node->score);
       }
       results.emplace_back(result, score - marginal);
     }
@@ -847,12 +846,13 @@ NBestEncodeResult Model::SampleEncodeAndScore(absl::string_view normalized,
   return results;
 }
 
-float Model::CalculateEntropy(absl::string_view normalized, float theta) const {
+float Model::CalculateEntropy(absl::string_view normalized,
+                              float inv_theta) const {
   Lattice lattice;
   lattice.SetSentence(normalized);
   PopulateNodes(&lattice);
 
-  return lattice.CalculateEntropy(theta);
+  return lattice.CalculateEntropy(inv_theta);
 }
 
 bool Model::VerifyOutputsEquivalent(absl::string_view expected,
index 448e489d88011e3490ced51de76bb098186c2df0..aa4f28f3b409fe173a58d2e4c2414b32be13542d 100644 (file)
@@ -173,6 +173,18 @@ class Model : public ModelInterface {
   bool VerifyOutputsEquivalent(absl::string_view expected,
                                absl::string_view actual) const override;
 
+  enum EncoderVersion {
+    kOptimized,  // The optimized encoder.
+    kOriginal    // The original encoder.
+  };
+
+  void SetEncoderVersion(EncoderVersion encoder_version) {
+    encoder_version_ = encoder_version;
+  }
+
+  // Returns the current encoder version in use.
+  EncoderVersion GetEncoderVersion() const { return encoder_version_; }
+
  protected:
   // Builds a Trie index.
   void BuildTrie(std::vector<std::pair<absl::string_view, int>> *pieces);
@@ -195,6 +207,9 @@ class Model : public ModelInterface {
   // Maximum size of the return value of Trie, which corresponds
   // to the maximum size of shared common prefix in the sentence pieces.
   int trie_results_size_;
+
+  // encoder version.
+  EncoderVersion encoder_version_ = kOptimized;
 };
 
 }  // namespace unigram
index 8049d20f2daa3cfbde9141a5754a7b2ce2596ece..221bac200d2a5713eba280ed4730203635254f92 100644 (file)
@@ -12,6 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.!
 
+#include "unigram_model.h"
+
 #include <cmath>
 #include <map>
 #include <string>
@@ -22,7 +24,6 @@
 #include "testharness.h"
 #include "third_party/absl/strings/str_cat.h"
 #include "third_party/absl/strings/str_join.h"
-#include "unigram_model.h"
 #include "util.h"
 
 namespace sentencepiece {
@@ -249,14 +250,14 @@ TEST(LatticeTest, NBestSampleTest) {
 
   // Calculate expected probabilities of each path
   // Note that sampling without replacement affects the expected frequencies!
-  const std::vector<double> kTheta = {0.0, 0.01, 0.5, 0.7, 1.0};
-  for (const auto theta : kTheta) {
+  const std::vector<double> kInv_Theta = {0.0, 0.01, 0.5, 0.7, 1.0};
+  for (const auto inv_theta : kInv_Theta) {
     std::vector<std::string> strings = {"ABC", "AB C", "A BC", "A B C"};
     std::map<std::string, float> probs;
-    probs["ABC"] = std::exp(theta * 1.0);
-    probs["AB C"] = std::exp(theta * (0.2 + 0.1));
-    probs["A BC"] = std::exp(theta * (0.0 + 0.5));
-    probs["A B C"] = std::exp(theta * (0.0 + 0.0 + 0.1));
+    probs["ABC"] = std::exp(inv_theta * 1.0);
+    probs["AB C"] = std::exp(inv_theta * (0.2 + 0.1));
+    probs["A BC"] = std::exp(inv_theta * (0.0 + 0.5));
+    probs["A B C"] = std::exp(inv_theta * (0.0 + 0.0 + 0.1));
 
     for (const auto &it : strings) {
       EXPECT_EQ(1, probs.count(it));
@@ -298,7 +299,7 @@ TEST(LatticeTest, NBestSampleTest) {
     for (const auto num_samples : kNumSamples) {
       std::map<std::string, int> counts;
       for (int i = 0; i < kTrials; i++) {
-        auto nbests = lattice.NBest(num_samples, true, theta);
+        auto nbests = lattice.NBest(num_samples, true, inv_theta);
         for (const auto &nbest : nbests) {
           counts[GetTokenized(nbest.first)]++;
         }
@@ -329,14 +330,14 @@ TEST(LatticeTest, CalculateEntropyTest) {
   InsertWithScore(&lattice, 0, 3, 1.0);  // ABC
 
   // Calculate expected probabilities of each path
-  const std::vector<double> kTheta = {0.0, 0.01, 0.5, 0.7, 1.0};
-  for (const auto theta : kTheta) {
+  const std::vector<double> kInv_Theta = {0.0, 0.01, 0.5, 0.7, 1.0};
+  for (const auto inv_theta : kInv_Theta) {
     std::vector<std::string> strings = {"ABC", "AB C", "A BC", "A B C"};
     std::map<std::string, float> probs;
-    probs["ABC"] = std::exp(theta * 1.0);
-    probs["AB C"] = std::exp(theta * (0.2 + 0.1));
-    probs["A BC"] = std::exp(theta * (0.0 + 0.5));
-    probs["A B C"] = std::exp(theta * (0.0 + 0.0 + 0.1));
+    probs["ABC"] = std::exp(inv_theta * 1.0);
+    probs["AB C"] = std::exp(inv_theta * (0.2 + 0.1));
+    probs["A BC"] = std::exp(inv_theta * (0.0 + 0.5));
+    probs["A B C"] = std::exp(inv_theta * (0.0 + 0.0 + 0.1));
 
     double Z = 0.0;
     for (const auto &it : probs) Z += it.second;
@@ -349,7 +350,7 @@ TEST(LatticeTest, CalculateEntropyTest) {
     for (const auto &it : probs) {
       entropy += (it.second * std::log(it.second));
     }
-    EXPECT_NEAR(-entropy, lattice.CalculateEntropy(theta), 0.02);
+    EXPECT_NEAR(-entropy, lattice.CalculateEntropy(inv_theta), 0.02);
   }
 }
 
@@ -364,9 +365,9 @@ TEST(LatticeTest, ForwardAlgorithmTest) {
   InsertWithScore(&lattice, 1, 2, 0.5);  // BC
   InsertWithScore(&lattice, 0, 3, 1.0);  // ABC
 
-  const std::vector<float> kTheta = {0.0, 0.01, 0.5, 0.7, 1.0};
-  for (const auto theta : kTheta) {
-    std::vector<float> alpha = lattice.ForwardAlgorithm(theta);
+  const std::vector<float> kInv_Theta = {0.0, 0.01, 0.5, 0.7, 1.0};
+  for (const auto inv_theta : kInv_Theta) {
+    std::vector<float> alpha = lattice.ForwardAlgorithm(inv_theta);
     EXPECT_EQ(alpha.size(), 8);  // 6 nodes, plus BOS, EOS
     // only alpha[C], alpha[EOS] have non-zero alpha
     for (int i : {0, 1, 2, 3}) {
@@ -374,14 +375,15 @@ TEST(LatticeTest, ForwardAlgorithmTest) {
         if (i < 2) {
           EXPECT_EQ(alpha[node->node_id], 0.0);
         } else if (i == 2) {
-          float Z =
-              std::log(std::exp(theta * (0.0 + 0.0)) + std::exp(theta * 0.2));
+          float Z = std::log(std::exp(inv_theta * (0.0 + 0.0)) +
+                             std::exp(inv_theta * 0.2));
           EXPECT_EQ(alpha[node->node_id], Z);
         } else if (i == 3) {
-          float Z = std::log(std::exp(theta * (0.0 + 0.0 + 0.1)) +  // A + B + C
-                             std::exp(theta * (0.2 + 0.1)) +        // AB + C
-                             std::exp(theta * (0.0 + 0.5)) +        // A + BC
-                             std::exp(theta * 1.0));                // ABC
+          float Z =
+              std::log(std::exp(inv_theta * (0.0 + 0.0 + 0.1)) +  // A + B + C
+                       std::exp(inv_theta * (0.2 + 0.1)) +        // AB + C
+                       std::exp(inv_theta * (0.0 + 0.5)) +        // A + BC
+                       std::exp(inv_theta * 1.0));                // ABC
           EXPECT_EQ(Z, alpha[node->node_id]);
         }
       }
@@ -435,14 +437,14 @@ TEST(LatticeTest, SampleTest) {
   InsertWithScoreAndId(&lattice, 1, 2, 1.7, 4);  // BC
   InsertWithScoreAndId(&lattice, 0, 3, 1.8, 5);  // ABC
 
-  const std::vector<double> kTheta = {0.0, 0.01, 0.5, 0.7, 1.0};
-  for (int i = 0; i < kTheta.size(); ++i) {
+  const std::vector<double> kInv_Theta = {0.0, 0.01, 0.5, 0.7, 1.0};
+  for (int i = 0; i < kInv_Theta.size(); ++i) {
     std::map<std::string, double> probs;
     // Expands all paths in the lattice.
-    probs["A B C"] = exp(kTheta[i] * (1.0 + 1.2 + 1.5));  // A B C
-    probs["AB C"] = exp(kTheta[i] * (1.6 + 1.5));         // AB C
-    probs["A BC"] = exp(kTheta[i] * (1.0 + 1.7));         // A BC
-    probs["ABC"] = exp(kTheta[i] * 1.8);                  // ABC
+    probs["A B C"] = exp(kInv_Theta[i] * (1.0 + 1.2 + 1.5));  // A B C
+    probs["AB C"] = exp(kInv_Theta[i] * (1.6 + 1.5));         // AB C
+    probs["A BC"] = exp(kInv_Theta[i] * (1.0 + 1.7));         // A BC
+    probs["ABC"] = exp(kInv_Theta[i] * 1.8);                  // ABC
 
     // Computes expected probabilities.
     double Z = 0.0;
@@ -453,7 +455,7 @@ TEST(LatticeTest, SampleTest) {
     constexpr int kTrial = 100000;
     std::map<std::string, int> freq;
     for (int n = 0; n < kTrial; ++n) {
-      freq[GetTokenized(lattice.Sample(kTheta[i]))]++;
+      freq[GetTokenized(lattice.Sample(kInv_Theta[i]))]++;
     }
 
     EXPECT_EQ(probs.size(), freq.size());
@@ -480,18 +482,18 @@ ModelProto MakeBaseModelProto() {
 }
 
 // Returns model protos in parameterized tests.
-const std::vector<EncoderVersion> &GetEncoderVersions() {
-  static const std::vector<EncoderVersion> &v =
-      *new std::vector<EncoderVersion>{EncoderVersion::kOptimized,
-                                       EncoderVersion::kOriginal};
+const std::vector<Model::EncoderVersion> &GetEncoderVersions() {
+  static const std::vector<Model::EncoderVersion> &v =
+      *new std::vector<Model::EncoderVersion>{Model::kOptimized,
+                                              Model::kOriginal};
   return v;
 }
 
-class UnigramModelTest : public test::TestWithParam<EncoderVersion> {
+class UnigramModelTest : public test::TestWithParam<Model::EncoderVersion> {
  protected:
   void SetUp() override { encoder_version_ = GetParam(); }
   void TearDown() override {}
-  EncoderVersion encoder_version_;
+  Model::EncoderVersion encoder_version_;
 };
 
 void AddPiece(ModelProto *model_proto, const std::string &piece,
@@ -530,15 +532,15 @@ TEST(UnigramModelTest, SampleEncodeAndScoreTest) {
   lattice.SetSentence("ABC");
   model.PopulateNodes(&lattice);
 
-  std::vector<float> kTheta = {0.0, 1.0};
+  std::vector<float> kInv_Theta = {0.0, 1.0};
 
-  for (const auto theta : kTheta) {
+  for (const auto inv_theta : kInv_Theta) {
     std::vector<std::string> strings = {"ABC", "AB C", "A BC", "A B C"};
     std::map<std::string, float> probs;
-    probs["ABC"] = std::exp(theta * 1.0);
-    probs["AB C"] = std::exp(theta * (0.2 + 0.1));
-    probs["A BC"] = std::exp(theta * (0.0 + 0.5));
-    probs["A B C"] = std::exp(theta * (0.0 + 0.0 + 0.1));
+    probs["ABC"] = std::exp(inv_theta * 1.0);
+    probs["AB C"] = std::exp(inv_theta * (0.2 + 0.1));
+    probs["A BC"] = std::exp(inv_theta * (0.0 + 0.5));
+    probs["A B C"] = std::exp(inv_theta * (0.0 + 0.0 + 0.1));
 
     for (const auto &it : strings) {
       EXPECT_EQ(1, probs.count(it));
@@ -579,8 +581,8 @@ TEST(UnigramModelTest, SampleEncodeAndScoreTest) {
       std::map<std::string, float> scores;
       int kTrials = 50000;
       for (int i = 0; i < kTrials; i++) {
-        NBestEncodeResult sample =
-            model.SampleEncodeAndScore("ABC", theta, num_samples, true, false);
+        NBestEncodeResult sample = model.SampleEncodeAndScore(
+            "ABC", inv_theta, num_samples, true, false);
 
         for (const auto &it : sample) {
           std::vector<std::string> tokens;
@@ -619,7 +621,7 @@ TEST_P(UnigramModelTest, PieceToIdTest) {
   AddPiece(&model_proto, "d", 0.4);
 
   Model model(model_proto);
-  EXPECT_TRUE(model.SetEncoderVersion(encoder_version_).ok());
+  model.SetEncoderVersion(encoder_version_);
 
   EXPECT_EQ(model_proto.SerializeAsString(),
             model.model_proto().SerializeAsString());
@@ -677,7 +679,7 @@ TEST_P(UnigramModelTest, PopulateNodesAllUnknownsTest) {
   ModelProto model_proto = MakeBaseModelProto();
   AddPiece(&model_proto, "x");
   Model model(model_proto);
-  EXPECT_TRUE(model.SetEncoderVersion(encoder_version_).ok());
+  model.SetEncoderVersion(encoder_version_);
 
   Lattice lattice;
   lattice.SetSentence("abc");
@@ -701,7 +703,7 @@ TEST_P(UnigramModelTest, PopulateNodesTest) {
   AddPiece(&model_proto, "bc", 0.4);  // 6
 
   Model model(model_proto);
-  EXPECT_TRUE(model.SetEncoderVersion(encoder_version_).ok());
+  model.SetEncoderVersion(encoder_version_);
 
   Lattice lattice;
   lattice.SetSentence("abc");
@@ -736,7 +738,7 @@ TEST_P(UnigramModelTest, PopulateNodesWithUnusedTest) {
   model_proto.mutable_pieces(6)->set_type(ModelProto::SentencePiece::UNUSED);
 
   Model model(model_proto);
-  EXPECT_TRUE(model.SetEncoderVersion(encoder_version_).ok());
+  model.SetEncoderVersion(encoder_version_);
 
   Lattice lattice;
   lattice.SetSentence("abc");
@@ -761,7 +763,7 @@ TEST_P(UnigramModelTest, ModelNBestTest) {
   AddPiece(&model_proto, "abc", 10.0);  // 8
 
   Model model(model_proto);
-  EXPECT_TRUE(model.SetEncoderVersion(encoder_version_).ok());
+  model.SetEncoderVersion(encoder_version_);
 
   auto nbest = model.NBestEncode("", 10);
   EXPECT_EQ(1, nbest.size());
@@ -800,7 +802,7 @@ TEST_P(UnigramModelTest, EncodeTest) {
       ModelProto::SentencePiece::USER_DEFINED);
 
   Model model(model_proto);
-  EXPECT_TRUE(model.SetEncoderVersion(encoder_version_).ok());
+  model.SetEncoderVersion(encoder_version_);
 
   EncodeResult result;
 
@@ -883,7 +885,7 @@ TEST_P(UnigramModelTest, EncodeWithUnusedTest) {
   // No unused.
   {
     Model model(model_proto);
-    EXPECT_TRUE(model.SetEncoderVersion(encoder_version_).ok());
+    model.SetEncoderVersion(encoder_version_);
     const auto result = model.Encode("abcd");
     EXPECT_EQ(1, result.size());
     EXPECT_EQ("abcd", result[0].first);
@@ -892,7 +894,7 @@ TEST_P(UnigramModelTest, EncodeWithUnusedTest) {
   {
     model_proto.mutable_pieces(3)->set_type(ModelProto::SentencePiece::UNUSED);
     Model model(model_proto);
-    EXPECT_TRUE(model.SetEncoderVersion(encoder_version_).ok());
+    model.SetEncoderVersion(encoder_version_);
     const auto result = model.Encode("abcd");
     EXPECT_EQ(2, result.size());
     EXPECT_EQ("abc", result[0].first);
@@ -903,7 +905,7 @@ TEST_P(UnigramModelTest, EncodeWithUnusedTest) {
     model_proto.mutable_pieces(3)->set_type(ModelProto::SentencePiece::UNUSED);
     model_proto.mutable_pieces(5)->set_type(ModelProto::SentencePiece::UNUSED);
     Model model(model_proto);
-    EXPECT_TRUE(model.SetEncoderVersion(encoder_version_).ok());
+    model.SetEncoderVersion(encoder_version_);
     const auto result = model.Encode("abcd");
     EXPECT_EQ(2, result.size());
     EXPECT_EQ("abc", result[0].first);
@@ -917,7 +919,7 @@ TEST_P(UnigramModelTest, EncodeWithUnusedTest) {
     model_proto.mutable_pieces(4)->set_type(ModelProto::SentencePiece::UNUSED);
     model_proto.mutable_pieces(5)->set_type(ModelProto::SentencePiece::NORMAL);
     Model model(model_proto);
-    EXPECT_TRUE(model.SetEncoderVersion(encoder_version_).ok());
+    model.SetEncoderVersion(encoder_version_);
     const auto result = model.Encode("abcd");
     EXPECT_EQ(2, result.size());
     EXPECT_EQ("ab", result[0].first);
@@ -937,7 +939,7 @@ TEST_P(UnigramModelTest, VerifyOutputsEquivalent) {
   AddPiece(&model_proto, "c", 2.0);      // 9
   AddPiece(&model_proto, "d", 1.0);      // 10
   Model model(model_proto);
-  EXPECT_TRUE(model.SetEncoderVersion(encoder_version_).ok());
+  model.SetEncoderVersion(encoder_version_);
   // Equivalent outputs.
   EXPECT_TRUE(model.VerifyOutputsEquivalent("", ""));
   EXPECT_TRUE(model.VerifyOutputsEquivalent("a b", "a b"));
index 285676dad28af384e59dfcad01aaa4f97dfaf32f..fb312f100982f8efd0458c58b8dedce473cdf9f2 100644 (file)
@@ -60,17 +60,6 @@ uint32 GetRandomGeneratorSeed();
 // String utilities
 namespace string_util {
 
-struct string_view_hash {
-  // DJB hash function.
-  inline size_t operator()(const absl::string_view &sp) const {
-    size_t hash = 5381;
-    for (size_t i = 0; i < sp.size(); ++i) {
-      hash = ((hash << 5) + hash) + sp[i];
-    }
-    return hash;
-  }
-};
-
 template <typename Target>
 inline bool lexical_cast(absl::string_view arg, Target *result) {
   std::stringstream ss;