// See the License for the specific language governing permissions and
// limitations under the License.
+#include "bpe_model.h"
+
#include <functional>
#include <memory>
#include <queue>
#include <utility>
#include <vector>
-#include "bpe_model.h"
#include "freelist.h"
#include "third_party/absl/container/flat_hash_map.h"
#include "util.h"
// Reverse merge rules.
// key: merged symbol, value: pair of original symbols.
absl::flat_hash_map<absl::string_view,
- std::pair<absl::string_view, absl::string_view>,
- string_util::string_view_hash>
+ std::pair<absl::string_view, absl::string_view>>
rev_merge;
// Pre-allocates SymbolPair for efficiency.
// Given a normalized string, returns a sequence of sentence pieces with ids.
class ModelInterface {
public:
- using PieceToIdMap = absl::flat_hash_map<absl::string_view, int,
- string_util::string_view_hash>;
+ using PieceToIdMap = absl::flat_hash_map<absl::string_view, int>;
absl::string_view unk_piece() const;
absl::string_view bos_piece() const;
return matcher_.get();
}
- // Sets the encoder version. Currently only unigram has an optimized encoder.
- // The optimized version is always used by default if there is one, so
- // normally users do not need to call this function. This function is provided
- // just in case that a user want to manually choose which encoder version to
- // use.
- virtual util::Status SetEncoderVersion(EncoderVersion encoder_version) {
- encoder_version_ = encoder_version;
- return util::OkStatus();
- }
-
- // Returns the current encoder version in use.
- virtual EncoderVersion GetEncoderVersion() const { return encoder_version_; }
-
// Given a normalized string, returns a sequence of sentence pieces with ids.
// The concatenation of pieces must be the same as `normalized`.
virtual EncodeResult Encode(absl::string_view normalized) const = 0;
}
// Calculates the entropy of the segmentation lattice with inverse temperature
- // `theta`.
- // Uses a novel dynamic program to calculate the entropy.
+ // `alpha`. Uses a novel dynamic program to calculate the entropy.
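+ // A sketch of the quantity computed, assuming each candidate segmentation y
+ // of `normalized` is weighted as p(y) proportional to exp(alpha * score(y)):
+ //   H = -sum_y p(y) * log p(y).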
virtual float CalculateEntropy(absl::string_view normalized,
- float theta) const {
+ float alpha) const {
LOG(ERROR) << "Not implemented.";
return 0.0;
}
// unknown id.
int unk_id_ = 0;
- // The encoder version. Currently it is only effective for unigram model but
- // ignored by other models.
- EncoderVersion encoder_version_ = EncoderVersion::kOptimized;
-
// status.
util::Status status_;
};
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "model_factory.h"
#include "model_interface.h"
+
+#include "model_factory.h"
#include "testharness.h"
#include "third_party/absl/container/flat_hash_map.h"
#include "util.h"
EXPECT_EQ(PieceToByte("a"), -1);
}
-TEST(ModelInterfaceTest, SetEncoderVersion) {
- for (const auto type : kModelTypes) {
- ModelProto model_proto = MakeBaseModelProto(type);
- AddPiece(&model_proto, "a");
- AddPiece(&model_proto, "b");
- auto model = ModelFactory::Create(model_proto);
-
- // Verify the default encoder version.
- EXPECT_EQ(EncoderVersion::kOptimized, model->GetEncoderVersion());
-
- // Set the encoder version to original and verify.
- EXPECT_TRUE(model->SetEncoderVersion(EncoderVersion::kOriginal).ok());
- EXPECT_EQ(EncoderVersion::kOriginal, model->GetEncoderVersion());
- }
-}
-
TEST(ModelInterfaceTest, VerifyOutputsEquivalent) {
for (const auto type : kModelTypes) {
ModelProto model_proto = MakeBaseModelProto(type);
}
} // namespace
+ImmutableSentencePieceText::ImmutableSentencePieceText() {}
+ImmutableSentencePieceText::~ImmutableSentencePieceText() {}
+
+ImmutableSentencePieceText::ImmutableSentencePieceText(
+ const SentencePieceText &spt)
+ : spt_(&spt) {}
+
+ImmutableSentencePieceText::ImmutableSentencePiece::ImmutableSentencePiece(
+ const SentencePieceText_SentencePiece &sp)
+ : sp_(&sp) {}
+
+absl::string_view ImmutableSentencePieceText::ImmutableSentencePiece::piece()
+ const {
+ return sp_->piece();
+}
+
+absl::string_view ImmutableSentencePieceText::ImmutableSentencePiece::surface()
+ const {
+ return sp_->surface();
+}
+
+uint32_t ImmutableSentencePieceText::ImmutableSentencePiece::id() const {
+ return sp_->id();
+}
+
+uint32_t ImmutableSentencePieceText::ImmutableSentencePiece::begin() const {
+ return sp_->begin();
+}
+
+uint32_t ImmutableSentencePieceText::ImmutableSentencePiece::end() const {
+ return sp_->end();
+}
+
+std::vector<ImmutableSentencePieceText::ImmutableSentencePiece>
+ImmutableSentencePieceText::pieces() const {
+ std::vector<ImmutableSentencePieceText::ImmutableSentencePiece> pieces;
+ if (spt_ == nullptr) return pieces;
+ pieces.reserve(spt_->pieces_size());
+ for (int i = 0; i < spt_->pieces_size(); ++i)
+ pieces.push_back(ImmutableSentencePiece(spt_->pieces(i)));
+ return pieces;
+}
+
+size_t ImmutableSentencePieceText::pieces_size() const {
+ return spt_ ? spt_->pieces_size() : 0;
+}
+
+ImmutableSentencePieceText::ImmutableSentencePiece
+ImmutableSentencePieceText::pieces(int index) const {
+ return ImmutableSentencePieceText::ImmutableSentencePiece(
+ spt_->pieces(index));
+}
+
+absl::string_view ImmutableSentencePieceText::text() const {
+ return spt_ ? spt_->text() : "";
+}
+
+float ImmutableSentencePieceText::score() const {
+ return spt_ ? spt_->score() : 0.0;
+}
+
+SentencePieceText *ImmutableSentencePieceText::mutable_proto() {
+ if (rep_ == nullptr) {
+ rep_ = std::make_shared<SentencePieceText>();
+ spt_ = rep_.get();
+ }
+ return rep_.get();
+}
+
+std::string ImmutableSentencePieceText::SerializeAsString() const {
+ return spt_ ? spt_->SerializeAsString() : "";
+}
+
+ImmutableNBestSentencePieceText::ImmutableNBestSentencePieceText() {}
+ImmutableNBestSentencePieceText::~ImmutableNBestSentencePieceText() {}
+
+size_t ImmutableNBestSentencePieceText::nbests_size() const {
+ return rep_ ? rep_->nbests_size() : 0;
+}
+
+ImmutableSentencePieceText ImmutableNBestSentencePieceText::nbests(
+ int index) const {
+ return ImmutableSentencePieceText(rep_->nbests(index));
+}
+
+std::vector<ImmutableSentencePieceText>
+ImmutableNBestSentencePieceText::nbests() const {
+ std::vector<ImmutableSentencePieceText> nbests;
+ if (rep_ == nullptr) return nbests;
+ nbests.reserve(rep_->nbests_size());
+ for (int i = 0; i < rep_->nbests_size(); ++i)
+ nbests.push_back(ImmutableSentencePieceText(rep_->nbests(i)));
+ return nbests;
+}
+
+NBestSentencePieceText *ImmutableNBestSentencePieceText::mutable_proto() {
+ if (rep_ == nullptr) {
+ rep_ = std::make_shared<NBestSentencePieceText>();
+ }
+ return rep_.get();
+}
+
+std::string ImmutableNBestSentencePieceText::SerializeAsString() const {
+ return rep_ ? rep_->SerializeAsString() : "";
+}
+
SentencePieceProcessor::SentencePieceProcessor() {}
SentencePieceProcessor::~SentencePieceProcessor() {}
return util::OkStatus();
}
-util::Status SentencePieceProcessor::SetEncoderVersion(
- EncoderVersion encoder_version) {
- return model_->SetEncoderVersion(encoder_version);
-}
-
-EncoderVersion SentencePieceProcessor::GetEncoderVersion() const {
- return model_->GetEncoderVersion();
-}
-
util::Status SentencePieceProcessor::SetEncodeExtraOptions(
absl::string_view extra_options) {
return ParseExtraOptions(extra_options, &encode_extra_options_);
}
util::Status SentencePieceProcessor::SampleEncodeAndScore(
- absl::string_view input, int num_samples, float theta, bool wor,
+ absl::string_view input, int num_samples, float alpha, bool wor,
bool include_best,
std::vector<std::pair<std::vector<std::string>, float>> *pieces) const {
CHECK_OR_RETURN_STATUS_STL(pieces);
NBestSentencePieceText spt;
RETURN_IF_ERROR(
- SampleEncodeAndScore(input, num_samples, theta, wor, include_best, &spt));
+ SampleEncodeAndScore(input, num_samples, alpha, wor, include_best, &spt));
pieces->clear();
pieces->reserve(spt.nbests_size());
}
util::Status SentencePieceProcessor::SampleEncodeAndScore(
- absl::string_view input, int num_samples, float theta, bool wor,
+ absl::string_view input, int num_samples, float alpha, bool wor,
bool include_best,
std::vector<std::pair<std::vector<int>, float>> *ids) const {
CHECK_OR_RETURN_STATUS_STL(ids);
NBestSentencePieceText spt;
RETURN_IF_ERROR(
- SampleEncodeAndScore(input, num_samples, theta, wor, include_best, &spt));
+ SampleEncodeAndScore(input, num_samples, alpha, wor, include_best, &spt));
ids->clear();
ids->reserve(spt.nbests_size());
}
util::Status SentencePieceProcessor::SampleEncodeAndScore(
- absl::string_view input, int samples, float theta, bool wor,
+ absl::string_view input, int samples, float alpha, bool wor,
bool include_best, NBestSentencePieceText *samples_spt) const {
CHECK_OR_RETURN(model_->IsSampleEncodeAndScoreAvailable())
<< "SampleEncodeAndScore is not available for the current model.";
std::vector<size_t> norm_to_orig;
RETURN_IF_ERROR(normalizer_->Normalize(input, &normalized, &norm_to_orig));
- const auto results = model_->SampleEncodeAndScore(normalized, theta, samples,
+ const auto results = model_->SampleEncodeAndScore(normalized, alpha, samples,
wor, include_best);
CHECK_OR_RETURN(!results.empty())
<< "SampleEncodeAndScore returns empty result.";
}
util::Status SentencePieceProcessor::CalculateEntropy(absl::string_view input,
- float theta,
+ float alpha,
float *entropy) const {
CHECK_OR_RETURN(model_->IsCalculateEntropyAvailable())
<< "CalculateEntropy is not available for the current model.";
std::vector<size_t> norm_to_orig;
RETURN_IF_ERROR(normalizer_->Normalize(input, &normalized, &norm_to_orig));
- *entropy = model_->CalculateEntropy(normalized, theta);
+ *entropy = model_->CalculateEntropy(normalized, alpha);
return util::OkStatus();
}
return Decode(pieces, spt);
}
-std::string SentencePieceProcessor::EncodeAsSerializedProto(
- absl::string_view input) const {
- SentencePieceText spt;
- if (!Encode(input, &spt).ok()) return "";
- return spt.SerializeAsString();
-}
-
-std::string SentencePieceProcessor::SampleEncodeAsSerializedProto(
- absl::string_view input, int nbest_size, float alpha) const {
- SentencePieceText spt;
- if (!SampleEncode(input, nbest_size, alpha, &spt).ok()) return "";
- return spt.SerializeAsString();
-}
-
-std::string SentencePieceProcessor::NBestEncodeAsSerializedProto(
- absl::string_view input, int nbest_size) const {
- NBestSentencePieceText spt;
- if (!NBestEncode(input, nbest_size, &spt).ok()) return "";
- return spt.SerializeAsString();
-}
-
-std::string SentencePieceProcessor::DecodePiecesAsSerializedProto(
- const std::vector<std::string> &pieces) const {
- SentencePieceText spt;
- if (!Decode(pieces, &spt).ok()) return "";
- return spt.SerializeAsString();
-}
-
-std::string SentencePieceProcessor::DecodePiecesAsSerializedProto(
- const std::vector<absl::string_view> &pieces) const {
- SentencePieceText spt;
- if (!Decode(pieces, &spt).ok()) return "";
- return spt.SerializeAsString();
-}
-
-std::string SentencePieceProcessor::DecodeIdsAsSerializedProto(
- const std::vector<int> &ids) const {
- SentencePieceText spt;
- if (!Decode(ids, &spt).ok()) return "";
- return spt.SerializeAsString();
-}
-
#define CHECK_STATUS_OR_RETURN_DEFAULT(value) \
if (!status().ok()) { \
LOG(ERROR) << status().message() << "\nReturns default value " << value; \
#endif // SWIG
namespace sentencepiece {
-
-#ifndef SWIG
-using EncodeResult = std::vector<std::pair<absl::string_view, int>>;
-#endif // SWIG
-
namespace util {
enum class StatusCode : int {
// sp.Load("//path/to/model");
//
// vector<string> sps;
-// sp.Encode("hello world.", &sps);
+// sp.Encode("hello world.", &sps).IgnoreError();
//
// vector<int> ids;
-// sp.Encode("hello world.", &ids);
+// sp.Encode("hello world.", &ids).IgnoreError();
//
// string detok;
// sp.Decode(sps, &detok);
// CHECK_EQ("hello world.", detok);
//
// sp.Decode(ids, &detok);
// CHECK_EQ("hello world.", detok);
//
// We can also use SentencePieceText which manages the byte-offsets
// between user input (output) and internal sentence pieces.
class Normalizer;
} // namespace normalizer
-#ifndef SWIG
-// Defines the multiple versions of encoder within each model. Currently only
-// the Unigram model has an optimized encoder.
-enum class EncoderVersion {
- kOptimized, // The optimized encoder (default).
- kOriginal // The original encoder (user may choose to fall back to this
- // just in case).
-};
-#endif
-
#ifndef SWIGGO
namespace util {
// Redefine std::string for serialized_proto interface as Python's string is
// with SWIG's typemap.
using bytes = std::string;
} // namespace util
-#endif
+#endif // SWIGGO
+
+class NBestSentencePieceText;
+class ModelInterface;
+class SentencePieceText;
+class SentencePieceText_SentencePiece;
+
+// Wrapper class of SentencePieceText.
+// This wrapper only allows immutable access to the proto and
+// hides the actual implementation of protobuf.
+// See sentencepiece.proto for the details of this class.
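+//
+// Example (illustrative, assuming an initialized SentencePieceProcessor `sp`
+// and the EncodeAsImmutableProto API declared below):
+//   ImmutableSentencePieceText spt = sp.EncodeAsImmutableProto("hello world.");
+//   for (const auto &piece : spt.pieces()) {
+//     std::cout << piece.piece() << " " << piece.id() << std::endl;
+//   }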
+class ImmutableSentencePieceText {
+ public:
+ ImmutableSentencePieceText();
+ virtual ~ImmutableSentencePieceText();
+
+ class ImmutableSentencePiece {
+ public:
+ ~ImmutableSentencePiece() = default;
+ absl::string_view piece() const;
+ absl::string_view surface() const;
+ uint32_t id() const;
+ uint32_t begin() const;
+ uint32_t end() const;
+
+ friend class ImmutableSentencePieceText;
+
+ private:
+ ImmutableSentencePiece() = default;
+ explicit ImmutableSentencePiece(const SentencePieceText_SentencePiece &sp);
+ const SentencePieceText_SentencePiece *sp_ = nullptr;
+ };
+
+ std::vector<ImmutableSentencePiece> pieces() const;
+ size_t pieces_size() const;
+ ImmutableSentencePiece pieces(int index) const;
+ absl::string_view text() const;
+ float score() const;
+
+ std::string SerializeAsString() const;
+
+ // Returns the actual mutable proto.
+ // Do not use this outside of SentencePieceProcessor, as
+ // it returns the raw pointer managed by the shared_ptr.
+ SentencePieceText *mutable_proto();
+
+ friend class ImmutableNBestSentencePieceText;
+ friend class SentencePieceProcessor;
+
+ private:
+ explicit ImmutableSentencePieceText(const SentencePieceText &spt);
+ const SentencePieceText *spt_ = nullptr;
+ std::shared_ptr<SentencePieceText> rep_;
+};
+
+// Wrapper class of NBestSentencePieceText.
+// This wrapper only allows immutable access to the proto and
+// hides the actual implementation of protobuf.
+// See sentencepiece.proto for the details of this class.
+class ImmutableNBestSentencePieceText {
+ public:
+ ImmutableNBestSentencePieceText();
+ virtual ~ImmutableNBestSentencePieceText();
+
+ std::vector<ImmutableSentencePieceText> nbests() const;
+
+ size_t nbests_size() const;
+ ImmutableSentencePieceText nbests(int index) const;
+
+ std::string SerializeAsString() const;
+
+ // Returns the actual mutable proto.
+ // Do not use this outside of SentencePieceProcessor, as
+ // it returns the raw pointer managed by the shared_ptr.
+ NBestSentencePieceText *mutable_proto();
+
+ friend class SentencePieceProcessor;
+
+ private:
+ std::shared_ptr<NBestSentencePieceText> rep_;
+};
class SentencePieceProcessor {
public:
int threshold);
//////////////////////////////////////////////////////////////
- // Simple API.
+ // Simple Encode and Decode API.
//
// Given a UTF8 input, encodes it into a sequence of sentence pieces.
virtual util::Status Encode(absl::string_view input,
virtual util::Status Decode(const std::vector<int> &ids,
std::string *detokenized) const;
-#ifndef SWIG
- // Sets the encoder version. Normally users do not need to call this function.
- // But they can call this fucntion just in case if they want to fall back to
- // the original encoder.
- virtual util::Status SetEncoderVersion(EncoderVersion encoder_version);
-
- // Returns the current encoder version in use.
- virtual EncoderVersion GetEncoderVersion() const;
-#endif
-
//////////////////////////////////////////////////////////////
// NBest API.
+ //
// Same as Encode, but returns nbest results.
virtual util::Status NBestEncode(
absl::string_view input, int nbest_size,
//////////////////////////////////////////////////////////////
// Sampling API.
+ //
// Unigram and BPE support sampling mode.
// - Unigram (--model_type=unigram):
- // When `nbest_size` is positive value, approximately samples one
- // segmentation from nbest candidates. When `nbest_size` is negative value,
- // samples one segmentation from the hypotheses (Lattice) according to the
- // generation probabilities using forward-filtering and backward-sampling
- // algorithm. `alpha` is a smoothing parameter. The best segmentation
- // (Viterbi segmentation) is more likely sampled when setting larger
- // alpha. When alpha is 0.0, one segmentation is uniformly sampled from the
- // nbest or lattice.
- // `nbest_size` and `alpha` correspond to parameters `l` and `alpha`
+ //   `nbest_size`: When `nbest_size` is a positive value, approximately
+ //   samples one segmentation from the nbest candidates. When `nbest_size` is
+ //   a negative value, samples one segmentation from the hypotheses (Lattice)
+ //   according to the generation probabilities using the forward-filtering
+ //   and backward-sampling algorithm.
+ //   `alpha`: Smoothing parameter (inverse temperature). The best segmentation
+ //   (the Viterbi segmentation) is more likely to be sampled with a larger
+ //   alpha. When alpha is 0.0, one segmentation is uniformly sampled from the
+ //   nbest candidates or the lattice. `nbest_size` and `alpha` correspond to
+ //   parameters `l` and `alpha`
// in https://arxiv.org/abs/1804.10959 (nbest_size < 0 means l = infinity)
//
// - BPE (--model_type=bpe):
- // `alpha` is the dropout probability `p` of bpe merge operations
- // in https://arxiv.org/abs/1910.13267
- // Nbest-based sampling is not supported so nbest_size parameter is ignored in
- // BPE.
+ //   `alpha`: The dropout probability `p` of BPE merge operations in
+ //   https://arxiv.org/abs/1910.13267. Nbest-based sampling is not supported,
+ //   so the `nbest_size` parameter is ignored in BPE.
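+ //
+ // Example (illustrative, assuming an initialized SentencePieceProcessor `sp`):
+ //   std::vector<std::string> pieces;
+ //   sp.SampleEncode("hello world.", -1, 0.1, &pieces).IgnoreError();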
virtual util::Status SampleEncode(absl::string_view input, int nbest_size,
float alpha,
std::vector<std::string> *pieces) const;
//////////////////////////////////////////////////////////////
// SampleEncodeAndScore API.
- // Similar to SampleEncode, but returns samples results.
+ //
+ // Samples `num_samples` tokenisations from the segmentation lattice.
+ // These methods are only available in model_type=unigram.
+ //
+ // `alpha`: Smoothing parameter (inverse temperature). The same as `alpha` in
+ // the SampleEncode API above.
+ // `wor`: If `wor` is true, the samples are taken without replacement, and the
+ // scores are the inclusion probabilities of the elements in the sample;
+ // otherwise the samples are taken with replacement and the scores are the
+ // log-probs of the sample elements.
+ // `include_best`: If `include_best` is true, the best tokenisation is always
+ // included in the sample, and the remaining elements are sampled excluding
+ // the best.
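+ //
+ // Example (illustrative, assuming an initialized SentencePieceProcessor `sp`):
+ //   std::vector<std::pair<std::vector<std::string>, float>> samples;
+ //   sp.SampleEncodeAndScore("hello world.", /*num_samples=*/10, /*alpha=*/0.1,
+ //                           /*wor=*/true, /*include_best=*/false,
+ //                           &samples).IgnoreError();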
virtual util::Status SampleEncodeAndScore(
- absl::string_view input, int num_samples, float theta, bool wor,
+ absl::string_view input, int num_samples, float alpha, bool wor,
bool include_best,
std::vector<std::pair<std::vector<std::string>, float>> *pieces) const;
// Same as above, but returns a sequence of ids.
virtual util::Status SampleEncodeAndScore(
- absl::string_view input, int num_samples, float theta, bool wor,
+ absl::string_view input, int num_samples, float alpha, bool wor,
bool include_best,
std::vector<std::pair<std::vector<int>, float>> *ids) const;
+ //////////////////////////////////////////////////////////////
+ // Entropy API.
+ //
+ // This is only available in model_type=unigram.
+ // Calculates the entropy of possible tokenisations.
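+ //
+ // Example (illustrative, assuming an initialized SentencePieceProcessor `sp`):
+ //   float entropy = 0.0;
+ //   sp.CalculateEntropy("hello world.", /*alpha=*/1.0, &entropy).IgnoreError();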
+ virtual util::Status CalculateEntropy(absl::string_view input, float alpha,
+ float *entropy) const;
+
//////////////////////////////////////////////////////////////
// Advanced API returning SentencePieceText, which manages
// utf8-byte alignments between user-input/detokenized text
// and internal sentencepiece sequence.
//
// Given a UTF8 input, encodes it into SentencePieceText.
+ //
+ // When using these APIs, the sentencepiece.pb.h header file must be included.
+ // We can also use ImmutableSentencePieceText as follows.
+ //
+ // ImmutableSentencePieceText spt;
+ // Encode("hello", spt.mutable_proto()).IgnoreError();
+ // std::cout << spt.pieces_size() << std::endl;
virtual util::Status Encode(absl::string_view input,
SentencePieceText *spt) const;
- // Same as above, but returns NBestSentencePieceText.
virtual util::Status NBestEncode(absl::string_view input, int nbest_size,
NBestSentencePieceText *nbest_spt) const;
- // Same as above, but samples one segmentation from the hypotheses
- // (Lattice).
virtual util::Status SampleEncode(absl::string_view input, int nbest_size,
float alpha, SentencePieceText *spt) const;
- // Samples N segmentation and returns the scores as well
virtual util::Status SampleEncodeAndScore(
- absl::string_view input, int samples, float theta, bool wor,
+ absl::string_view input, int samples, float alpha, bool wor,
bool include_best, NBestSentencePieceText *samples_spt) const;
- // Calculate entropy of possible tokenisations
- virtual util::Status CalculateEntropy(absl::string_view input, float theta,
- float *entropy) const;
-
- // Given a sequence of pieces, decodes it into SentencePieceText.
- // TODO(taku): Remove this API and use std::vector<std::string_view>
+ // DEPRECATED: Remove this API and use std::vector<std::string_view>
virtual util::Status Decode(const std::vector<std::string> &pieces,
SentencePieceText *spt) const;
- // Given a sequence of pieces, decodes it into SentencePieceText.
virtual util::Status Decode(const std::vector<absl::string_view> &pieces,
SentencePieceText *spt) const;
- // Given a sequence of ids, decodes it into SentencePieceText.
virtual util::Status Decode(const std::vector<int> &ids,
SentencePieceText *spt) const;
- //////////////////////////////////////////////////////////////
- // Handy methods that return the result directly.
- // These functions ignore internal errors.
#ifdef SWIG
-#define DEFINE_SPP_DIRECT_FUNC_IMPL(FuncName, OutType, ...) \
- OutType output; \
- const auto _status = FuncName(__VA_ARGS__, &output); \
- if (!_status.ok()) throw _status; \
- return output;
+#define SPP_SWIG_CHECK_AND_THROW \
+ if (!status.ok()) throw status;
#else
+#define SPP_SWIG_CHECK_AND_THROW \
+ if (!status.ok()) { \
+ }
+#endif // SWIG
+
#define DEFINE_SPP_DIRECT_FUNC_IMPL(FuncName, OutType, ...) \
OutType output; \
- FuncName(__VA_ARGS__, &output).IgnoreError(); \
+ const auto status = FuncName(__VA_ARGS__, &output); \
+ SPP_SWIG_CHECK_AND_THROW; \
+ return output;
+
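+// The two macros below mirror DEFINE_SPP_DIRECT_FUNC_IMPL for the proto-based
+// APIs: they run `FuncName` into a temporary wrapper via mutable_proto(),
+// apply the SWIG-aware status check, and return either the serialized bytes
+// or the wrapper itself.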
+#define DEFINE_SPP_SERIALIZED_PROTO_IMPL(FuncName, OutType, ...) \
+ OutType output; \
+ const auto status = FuncName(__VA_ARGS__, output.mutable_proto()); \
+ SPP_SWIG_CHECK_AND_THROW; \
+ return output.SerializeAsString();
+
+#define DEFINE_SPP_IMMUTABLE_PROTO_IMPL(FuncName, OutType, ...) \
+ OutType output; \
+ const auto status = FuncName(__VA_ARGS__, output.mutable_proto()); \
+ SPP_SWIG_CHECK_AND_THROW; \
return output;
-#endif
+ //////////////////////////////////////////////////////////////
+ // Handy methods that return the result directly.
+ // These functions ignore internal errors.
virtual std::vector<std::string> EncodeAsPieces(
absl::string_view input) const {
DEFINE_SPP_DIRECT_FUNC_IMPL(Encode, std::vector<std::string>, input);
virtual std::vector<std::pair<std::vector<std::string>, float>>
SampleEncodeAndScoreAsPieces(absl::string_view input, int num_samples,
- float theta, bool wor, bool include_best) const {
+ float alpha, bool wor, bool include_best) const {
using _T = std::vector<std::pair<std::vector<std::string>, float>>;
DEFINE_SPP_DIRECT_FUNC_IMPL(SampleEncodeAndScore, _T, input, num_samples,
- theta, wor, include_best);
+ alpha, wor, include_best);
}
virtual std::vector<std::pair<std::vector<int>, float>>
SampleEncodeAndScoreAsIds(absl::string_view input, int num_samples,
- float theta, bool wor, bool include_best) const {
+ float alpha, bool wor, bool include_best) const {
using _T = std::vector<std::pair<std::vector<int>, float>>;
DEFINE_SPP_DIRECT_FUNC_IMPL(SampleEncodeAndScore, _T, input, num_samples,
- theta, wor, include_best);
+ alpha, wor, include_best);
}
- // TODO(taku): Remove this API and use std::vector<std::string_view>
+ // DEPRECATED: Remove this API and use std::vector<std::string_view>
virtual std::string DecodePieces(
const std::vector<std::string> &pieces) const {
DEFINE_SPP_DIRECT_FUNC_IMPL(Decode, std::string, pieces);
DEFINE_SPP_DIRECT_FUNC_IMPL(Decode, std::string, ids);
}
- virtual float CalculateEntropy(absl::string_view text, float theta) const {
- DEFINE_SPP_DIRECT_FUNC_IMPL(CalculateEntropy, float, text, theta);
+ virtual float CalculateEntropy(absl::string_view text, float alpha) const {
+ DEFINE_SPP_DIRECT_FUNC_IMPL(CalculateEntropy, float, text, alpha);
}
-#undef DEFINE_SPP_DIRECT_FUNC_IMPL
-
+ //////////////////////////////////////////////////////////////
+ // SerializedProto API (DEPRECATED). Use the ImmutableProto API instead.
// They are used in the Python interface and return a serialized proto.
// In the Python module, we can get access to the full proto after
// deserializing the returned byte sequence.
- virtual util::bytes EncodeAsSerializedProto(absl::string_view input) const;
+ virtual util::bytes EncodeAsSerializedProto(absl::string_view input) const {
+ DEFINE_SPP_SERIALIZED_PROTO_IMPL(Encode, ImmutableSentencePieceText, input);
+ }
virtual util::bytes SampleEncodeAsSerializedProto(absl::string_view input,
int nbest_size,
- float alpha) const;
+ float alpha) const {
+ DEFINE_SPP_SERIALIZED_PROTO_IMPL(SampleEncode, ImmutableSentencePieceText,
+ input, nbest_size, alpha);
+ }
virtual util::bytes NBestEncodeAsSerializedProto(absl::string_view input,
- int nbest_size) const;
+ int nbest_size) const {
+ DEFINE_SPP_SERIALIZED_PROTO_IMPL(
+ NBestEncode, ImmutableNBestSentencePieceText, input, nbest_size);
+ }
+
+ virtual util::bytes SampleEncodeAndScoreAsSerializedProto(
+ absl::string_view input, int samples, float alpha, bool wor,
+ bool include_best, int nbest_size) const {
+ DEFINE_SPP_SERIALIZED_PROTO_IMPL(SampleEncodeAndScore,
+ ImmutableNBestSentencePieceText, input,
+ samples, alpha, wor, include_best);
+ }
// TODO(taku): Remove this API and use std::vector<std::string_view>
virtual util::bytes DecodePiecesAsSerializedProto(
- const std::vector<std::string> &pieces) const;
+ const std::vector<std::string> &pieces) const {
+ DEFINE_SPP_SERIALIZED_PROTO_IMPL(Decode, ImmutableSentencePieceText,
+ pieces);
+ }
virtual util::bytes DecodePiecesAsSerializedProto(
- const std::vector<absl::string_view> &pieces) const;
+ const std::vector<absl::string_view> &pieces) const {
+ DEFINE_SPP_SERIALIZED_PROTO_IMPL(Decode, ImmutableSentencePieceText,
+ pieces);
+ }
virtual util::bytes DecodeIdsAsSerializedProto(
- const std::vector<int> &ids) const;
+ const std::vector<int> &ids) const {
+ DEFINE_SPP_SERIALIZED_PROTO_IMPL(Decode, ImmutableSentencePieceText, ids);
+ }
+
+ //////////////////////////////////////////////////////////////
+ // ImmutableProto API.
+ virtual ImmutableSentencePieceText EncodeAsImmutableProto(
+ absl::string_view input) const {
+ DEFINE_SPP_IMMUTABLE_PROTO_IMPL(Encode, ImmutableSentencePieceText, input);
+ }
+
+ virtual ImmutableSentencePieceText SampleEncodeAsImmutableProto(
+ absl::string_view input, int nbest_size, float alpha) const {
+ DEFINE_SPP_IMMUTABLE_PROTO_IMPL(SampleEncode, ImmutableSentencePieceText,
+ input, nbest_size, alpha);
+ }
+
+ virtual ImmutableNBestSentencePieceText NBestEncodeAsImmutableProto(
+ absl::string_view input, int nbest_size) const {
+ DEFINE_SPP_IMMUTABLE_PROTO_IMPL(
+ NBestEncode, ImmutableNBestSentencePieceText, input, nbest_size);
+ }
+
+ virtual ImmutableNBestSentencePieceText SampleEncodeAndScoreAsImmutableProto(
+ absl::string_view input, int samples, float alpha, bool wor,
+ bool include_best, int nbest_size) const {
+ DEFINE_SPP_IMMUTABLE_PROTO_IMPL(SampleEncodeAndScore,
+ ImmutableNBestSentencePieceText, input,
+ samples, alpha, wor, include_best);
+ }
+
+ // TODO(taku): Remove this API and use std::vector<std::string_view>
+ virtual ImmutableSentencePieceText DecodePiecesAsImmutableProto(
+ const std::vector<std::string> &pieces) const {
+ DEFINE_SPP_IMMUTABLE_PROTO_IMPL(Decode, ImmutableSentencePieceText, pieces);
+ }
+
+ virtual ImmutableSentencePieceText DecodePiecesAsImmutableProto(
+ const std::vector<absl::string_view> &pieces) const {
+ DEFINE_SPP_IMMUTABLE_PROTO_IMPL(Decode, ImmutableSentencePieceText, pieces);
+ }
+
+ virtual ImmutableSentencePieceText DecodeIdsAsImmutableProto(
+ const std::vector<int> &ids) const {
+ DEFINE_SPP_IMMUTABLE_PROTO_IMPL(Decode, ImmutableSentencePieceText, ids);
+ }
+
+#undef DEFINE_SPP_DIRECT_FUNC_IMPL
+#undef DEFINE_SPP_SERIALIZED_PROTO_IMPL
+#undef DEFINE_SPP_IMMUTABLE_PROTO_IMPL
//////////////////////////////////////////////////////////////
// Vocabulary management methods.
virtual const std::string &IdToPiece(int id) const;
// Returns the score of `id`.
- // Usually score is an emission log probability of unigram language model.
+ // Usually score is an emission log probability of unigram language
+ // model.
virtual float GetScore(int id) const;
// Returns true if `id` is unknown symbol.
// Allows injection of a normalizer instance. `normalizer` is moved.
void SetNormalizer(std::unique_ptr<normalizer::Normalizer> &&normalizer);
-#endif
+#endif // SWIG
// Returns immutable model proto. Useful to obtain extended
// or experimental parameters encoded in model_proto.
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "sentencepiece_processor.h"
+
#include <utility>
#include "builder.h"
#include "normalizer.h"
#include "sentencepiece.pb.h"
#include "sentencepiece_model.pb.h"
-#include "sentencepiece_processor.h"
#include "sentencepiece_trainer.h"
#include "testharness.h"
#include "third_party/absl/container/flat_hash_map.h"
int GetPieceSize() const override { return 7; }
int PieceToId(absl::string_view piece) const override {
- static absl::flat_hash_map<absl::string_view, int,
- string_util::string_view_hash>
- kMap = {{"<unk>", 0}, {"<s>", 1}, {"</s>", 2}, {WS "ABC", 3},
- {WS "DE", 4}, {"F", 5}, {"G" WS "H", 6}};
+ static absl::flat_hash_map<absl::string_view, int> kMap = {
+ {"<unk>", 0}, {"<s>", 1}, {"</s>", 2}, {WS "ABC", 3},
+ {WS "DE", 4}, {"F", 5}, {"G" WS "H", 6}};
return port::FindWithDefault(kMap, piece, 0);
}
int GetPieceSize() const override { return 7; }
int PieceToId(absl::string_view piece) const override {
- static absl::flat_hash_map<absl::string_view, int,
- string_util::string_view_hash>
- kMap = {{"<unk>", 0}, {"<s>", 1}, {"</s>", 2}, {WS "ABC", 3},
- {WS "DE", 4}, {"F", 5}, {"G" WS "H", 6}, {WS, 7}};
+ static absl::flat_hash_map<absl::string_view, int> kMap = {
+ {"<unk>", 0}, {"<s>", 1}, {"</s>", 2}, {WS "ABC", 3},
+ {WS "DE", 4}, {"F", 5}, {"G" WS "H", 6}, {WS, 7}};
return port::FindWithDefault(kMap, piece, 0);
}
EXPECT_EQ(2, sp.eos_id());
EXPECT_EQ(-1, sp.pad_id());
- {
- // Verify the default encoder version.
- EXPECT_EQ(EncoderVersion::kOptimized, sp.GetEncoderVersion());
-
- // Set the encoder version to original and verify.
- EXPECT_TRUE(sp.SetEncoderVersion(EncoderVersion::kOriginal).ok());
- EXPECT_EQ(EncoderVersion::kOriginal, sp.GetEncoderVersion());
-
- // Set back to the default encoder version.
- EXPECT_TRUE(sp.SetEncoderVersion(EncoderVersion::kOptimized).ok());
- }
-
{
std::vector<std::string> sps;
const std::vector<std::string> expected_str = {WS, "ab", "c"};
EXPECT_FALSE(sp.IsUnused(6));
EXPECT_FALSE(sp.IsUnused(7));
}
+
+TEST(SentencePieceProcessorTest, ImmutableSentencePieceTextTest) {
+ ImmutableSentencePieceText spt;
+ auto *v = spt.mutable_proto();
+
+ v->set_text("hello world");
+ v->set_score(1.0);
+ for (int i = 0; i < 10; ++i) {
+ auto *p = v->add_pieces();
+ p->set_surface(absl::StrCat("surface_", i));
+ p->set_piece(absl::StrCat("surface_", i));
+ p->set_id(i);
+ p->set_begin(i + 10);
+ p->set_end(i + 20);
+ }
+
+ EXPECT_EQ(v->pieces_size(), spt.pieces_size());
+ for (int i = 0; i < spt.pieces_size(); ++i) {
+ EXPECT_EQ(v->pieces(i).surface(), spt.pieces(i).surface());
+ EXPECT_EQ(v->pieces(i).piece(), spt.pieces(i).piece());
+ EXPECT_EQ(v->pieces(i).id(), spt.pieces(i).id());
+ EXPECT_EQ(v->pieces(i).begin(), spt.pieces(i).begin());
+ EXPECT_EQ(v->pieces(i).end(), spt.pieces(i).end());
+ }
+
+ int n = 0;
+ for (auto &p : spt.pieces()) {
+ EXPECT_EQ(v->pieces(n).surface(), p.surface());
+ EXPECT_EQ(v->pieces(n).piece(), p.piece());
+ EXPECT_EQ(v->pieces(n).id(), p.id());
+ EXPECT_EQ(v->pieces(n).begin(), p.begin());
+ EXPECT_EQ(v->pieces(n).end(), p.end());
+ ++n;
+ }
+
+ EXPECT_EQ(v->text(), spt.text());
+ EXPECT_EQ(v->score(), spt.score());
+ EXPECT_EQ(v->SerializeAsString(), spt.SerializeAsString());
+
+ // test copy.
+ auto spt2 = spt;
+ EXPECT_EQ(spt2.pieces_size(), spt.pieces_size());
+ for (int i = 0; i < spt.pieces_size(); ++i) {
+ EXPECT_EQ(spt2.pieces(i).surface(), spt.pieces(i).surface());
+ EXPECT_EQ(spt2.pieces(i).piece(), spt.pieces(i).piece());
+ EXPECT_EQ(spt2.pieces(i).id(), spt.pieces(i).id());
+ EXPECT_EQ(spt2.pieces(i).begin(), spt.pieces(i).begin());
+ EXPECT_EQ(spt2.pieces(i).end(), spt.pieces(i).end());
+ }
+}
+
+TEST(SentencePieceProcessorTest, ImmutableNBestSentencePieceTextTest) {
+ ImmutableNBestSentencePieceText spt;
+ auto *v = spt.mutable_proto();
+ for (int i = 0; i < 10; ++i) {
+ auto *p = v->add_nbests();
+ p->set_text(absl::StrCat("text_", i));
+ p->set_score(2.0 * i);
+ }
+
+ EXPECT_EQ(v->nbests_size(), spt.nbests_size());
+ for (int i = 0; i < v->nbests_size(); ++i) {
+ EXPECT_EQ(v->nbests(i).text(), spt.nbests(i).text());
+ EXPECT_EQ(v->nbests(i).score(), spt.nbests(i).score());
+ }
+ EXPECT_EQ(v->SerializeAsString(), spt.SerializeAsString());
+
+ // test copy.
+ auto spt2 = spt;
+ EXPECT_EQ(spt2.nbests_size(), spt.nbests_size());
+ EXPECT_EQ(spt2.SerializeAsString(), spt.SerializeAsString());
+}
+
} // namespace sentencepiece
return retval;
}
-std::vector<float> Lattice::ForwardAlgorithm(float theta) const {
+std::vector<float> Lattice::ForwardAlgorithm(float inv_theta) const {
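+ // A sketch of the recurrence implemented below: alpha[r] is the log
+ // partition function over all paths that reach the start position of node r,
+ //   alpha[r] = logsumexp_{l in end_nodes(pos(r))} (inv_theta * l->score + alpha[l]),
+ // where `inv_theta` scales the piece scores as an inverse temperature.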
const int len = size();
std::vector<float> alpha(node_allocator_.size(), 0.0);
for (int pos = 0; pos <= len; ++pos) {
for (Node *rnode : begin_nodes_[pos]) {
for (Node *lnode : end_nodes_[pos]) {
- alpha[rnode->node_id] = LogSumExp(
- alpha[rnode->node_id], theta * lnode->score + alpha[lnode->node_id],
- lnode == end_nodes_[pos][0]);
+ alpha[rnode->node_id] =
+ LogSumExp(alpha[rnode->node_id],
+ inv_theta * lnode->score + alpha[lnode->node_id],
+ lnode == end_nodes_[pos][0]);
}
}
}
return alpha;
}
-std::vector<float> Lattice::BackwardAlgorithm(float theta) const {
+std::vector<float> Lattice::BackwardAlgorithm(float inv_theta) const {
const int len = size();
std::vector<float> beta(node_allocator_.size(), 0.0);
return freq * Z;
}
-float Lattice::CalculateEntropy(float theta) const {
+float Lattice::CalculateEntropy(float inv_theta) const {
const int len = size();
// alpha[node_id] is the marginal prob of sequence up to start of node
// H is entropy of sequence
// the index of alpha/H is Node::node_id.
- std::vector<float> alpha(node_allocator_.size(), 0.0);
std::vector<float> H(node_allocator_.size(), 0.0);
// Populate the forward marginals to get the normalising constant
- alpha = ForwardAlgorithm(theta);
+ const auto alpha = ForwardAlgorithm(inv_theta);
// Now populate the forward entropies
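+ // (A sketch of the recurrence below.) For each incoming edge lnode -> rnode,
+ // with posterior log-prob
+ //   log p = inv_theta * lnode->score + alpha[lnode] - alpha[rnode],
+ // accumulate H[rnode] += exp(log p) * (H[lnode] + log p); the lattice entropy
+ // is then read off from H at the EOS node.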
for (int pos = 0; pos <= len; ++pos) {
// We have to normalise p(lnode) by the marginal contribution it makes
const float lnode_transition_prob =
- ((theta * lnode->score) + alpha[lnode->node_id] -
+ ((inv_theta * lnode->score) + alpha[lnode->node_id] -
alpha[rnode->node_id]);
H[rnode->node_id] += std::exp(lnode_transition_prob) *
(H[lnode->node_id] + lnode_transition_prob);
std::vector<Lattice::LatticePathWithScore> Lattice::NBest(size_t nbest_size,
bool sample,
- float theta) {
+ float inv_theta) {
if (nbest_size < 1) {
LOG(WARNING) << "nbest_size >= 1. Returns empty result.";
return {};
if (sample) {
// Run forwards algorithm to get normalising constants
- alpha = ForwardAlgorithm(theta);
+ alpha = ForwardAlgorithm(inv_theta);
// f(eos) = Gumbel(0), as it is the perturbed score of the entire lattice.
eos->fx = Gumbel();
} else {
for (int i = 0; i < end_nodes(node->pos).size(); i++) {
Node *lnode = end_nodes(node->pos)[i];
// Calculate backwards transition score
- probs[i] = top->gx + alpha[lnode->node_id] + (theta * lnode->score) - Z;
+ probs[i] =
+ top->gx + alpha[lnode->node_id] + (inv_theta * lnode->score) - Z;
perturbed_probs[i] = probs[i] + Gumbel();
if (perturbed_probs[i] > max_score) {
max_score = perturbed_probs[i];
return results;
}
-std::vector<Lattice::Node *> Lattice::Sample(float theta) {
+std::vector<Lattice::Node *> Lattice::Sample(float inv_theta) {
const int len = size();
if (len == 0) return {};
std::vector<float> alpha(node_allocator_.size(), 0.0);
- alpha = ForwardAlgorithm(theta);
+ alpha = ForwardAlgorithm(inv_theta);
auto *mt = random::GetRandomGenerator();
while (true) {
probs.clear();
for (const Node *lnode : end_nodes_[node->pos]) {
- probs.push_back(std::exp(static_cast<double>(alpha[lnode->node_id] +
- theta * lnode->score - Z)));
+ probs.push_back(std::exp(static_cast<double>(
+ alpha[lnode->node_id] + inv_theta * lnode->score - Z)));
}
std::discrete_distribution<int> dist(probs.begin(), probs.end());
node = end_nodes_[node->pos][dist(*mt)];
}
EncodeResult Model::SampleEncode(absl::string_view normalized,
- float theta) const {
+ float inv_theta) const {
if (!status().ok() || normalized.empty()) {
return {};
}
PopulateNodes(&lattice);
EncodeResult results;
- for (const auto *node : lattice.Sample(theta)) {
+ for (const auto *node : lattice.Sample(inv_theta)) {
results.emplace_back(node->piece, node->id);
}
}
NBestEncodeResult Model::SampleEncodeAndScore(absl::string_view normalized,
- float theta, int samples,
+ float inv_theta, int samples,
bool wor,
bool include_best) const {
if (!status().ok() || normalized.empty()) {
lattice.SetSentence(normalized);
PopulateNodes(&lattice);
- std::vector<float> alpha = lattice.ForwardAlgorithm(theta);
- float marginal = alpha[lattice.eos_node()->node_id];
+ const std::vector<float> alpha = lattice.ForwardAlgorithm(inv_theta);
+ const float marginal = alpha[lattice.eos_node()->node_id];
if (include_best) {
if (!wor) {
- LOG(FATAL) << "include_best not supported for wor false";
+ LOG(ERROR) << "include_best not supported for wor false";
+ return {};
}
EncodeResult result;
- Lattice::LatticePathWithScore best_path = lattice.Viterbi();
-
+ const auto best_path = lattice.Viterbi();
for (const auto *node : best_path.first) {
result.emplace_back(node->piece, node->id);
}
if (wor) {
// Draw k+1 samples as we need perturbed score of k+1th element
- std::vector<Lattice::LatticePathWithScore> nbest_samples =
- lattice.NBest(samples + 1, true, theta);
+ auto nbest_samples = lattice.NBest(samples + 1, true, inv_theta);
if (include_best) {
std::vector<std::vector<Lattice::Node *>> nbest_paths(
nbest_paths[i] = nbest_samples[i].first;
}
// Remove the best result from the samples if necessary
- Lattice::LatticePathWithScore best_path = lattice.Viterbi();
+ const auto best_path = lattice.Viterbi();
const int index_of_best =
(std::find(nbest_paths.begin(), nbest_paths.end(), best_path.first) -
nbest_paths.begin());
if (index_of_best != nbest_samples.size()) {
- LOG(INFO) << "removing best path from samples";
nbest_samples.erase(nbest_samples.begin() + index_of_best);
} else {
nbest_samples.pop_back();
float score = 0.0;
for (const auto *node : nbest.first) {
- score += (theta * node->score);
+ score += (inv_theta * node->score);
result.emplace_back(node->piece, node->id);
}
for (auto &it : results) {
// Only modify non best sample inclusion probabilities.
if (it.second != 0.0) {
- double x = it.second - kappa;
- double y = std::exp(x);
+ const double x = it.second - kappa;
+ const double y = std::exp(x);
double inclusion_prob;
if (x <= -10) {
// Series expansion of the log Gumbel survival function up to eps.
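+ // i.e. log(1 - exp(-exp(x))) ~= x - exp(x)/2 + exp(2*x)/24 for x << 0,
+ // with y = exp(x) as computed above.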
float score = 0.0;
EncodeResult result;
- std::vector<Lattice::Node *> sample = lattice.Sample(theta);
+ const std::vector<Lattice::Node *> sample = lattice.Sample(inv_theta);
for (const auto *node : sample) {
result.emplace_back(node->piece, node->id);
- score += (theta * node->score);
+ score += (inv_theta * node->score);
}
results.emplace_back(result, score - marginal);
}
return results;
}
-float Model::CalculateEntropy(absl::string_view normalized, float theta) const {
+float Model::CalculateEntropy(absl::string_view normalized,
+ float inv_theta) const {
Lattice lattice;
lattice.SetSentence(normalized);
PopulateNodes(&lattice);
- return lattice.CalculateEntropy(theta);
+ return lattice.CalculateEntropy(inv_theta);
}
bool Model::VerifyOutputsEquivalent(absl::string_view expected,
bool VerifyOutputsEquivalent(absl::string_view expected,
absl::string_view actual) const override;
+ enum EncoderVersion {
+ kOptimized, // The optimized encoder.
+ kOriginal // The original encoder.
+ };
+
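+ // Sets the encoder version. The optimized encoder is used by default, so
+ // this normally only needs to be called to fall back to the original encoder.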
+ void SetEncoderVersion(EncoderVersion encoder_version) {
+ encoder_version_ = encoder_version;
+ }
+
+ // Returns the current encoder version in use.
+ EncoderVersion GetEncoderVersion() const { return encoder_version_; }
+
protected:
// Builds a Trie index.
void BuildTrie(std::vector<std::pair<absl::string_view, int>> *pieces);
// Maximum size of the return value of Trie, which corresponds
// to the maximum size of shared common prefix in the sentence pieces.
int trie_results_size_;
+
+ // The encoder version in use (kOptimized by default).
+ EncoderVersion encoder_version_ = kOptimized;
};
} // namespace unigram
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "unigram_model.h"
+
#include <cmath>
#include <map>
#include <string>
#include "testharness.h"
#include "third_party/absl/strings/str_cat.h"
#include "third_party/absl/strings/str_join.h"
-#include "unigram_model.h"
#include "util.h"
namespace sentencepiece {
// Calculate expected probabilities of each path
// Note that sampling without replacement affects the expected frequencies!
- const std::vector<double> kTheta = {0.0, 0.01, 0.5, 0.7, 1.0};
- for (const auto theta : kTheta) {
+ const std::vector<double> kInv_Theta = {0.0, 0.01, 0.5, 0.7, 1.0};
+ for (const auto inv_theta : kInv_Theta) {
std::vector<std::string> strings = {"ABC", "AB C", "A BC", "A B C"};
std::map<std::string, float> probs;
- probs["ABC"] = std::exp(theta * 1.0);
- probs["AB C"] = std::exp(theta * (0.2 + 0.1));
- probs["A BC"] = std::exp(theta * (0.0 + 0.5));
- probs["A B C"] = std::exp(theta * (0.0 + 0.0 + 0.1));
+ probs["ABC"] = std::exp(inv_theta * 1.0);
+ probs["AB C"] = std::exp(inv_theta * (0.2 + 0.1));
+ probs["A BC"] = std::exp(inv_theta * (0.0 + 0.5));
+ probs["A B C"] = std::exp(inv_theta * (0.0 + 0.0 + 0.1));
for (const auto &it : strings) {
EXPECT_EQ(1, probs.count(it));
for (const auto num_samples : kNumSamples) {
std::map<std::string, int> counts;
for (int i = 0; i < kTrials; i++) {
- auto nbests = lattice.NBest(num_samples, true, theta);
+ auto nbests = lattice.NBest(num_samples, true, inv_theta);
for (const auto &nbest : nbests) {
counts[GetTokenized(nbest.first)]++;
}
InsertWithScore(&lattice, 0, 3, 1.0); // ABC
// Calculate expected probabilities of each path
- const std::vector<double> kTheta = {0.0, 0.01, 0.5, 0.7, 1.0};
- for (const auto theta : kTheta) {
+ const std::vector<double> kInv_Theta = {0.0, 0.01, 0.5, 0.7, 1.0};
+ for (const auto inv_theta : kInv_Theta) {
std::vector<std::string> strings = {"ABC", "AB C", "A BC", "A B C"};
std::map<std::string, float> probs;
- probs["ABC"] = std::exp(theta * 1.0);
- probs["AB C"] = std::exp(theta * (0.2 + 0.1));
- probs["A BC"] = std::exp(theta * (0.0 + 0.5));
- probs["A B C"] = std::exp(theta * (0.0 + 0.0 + 0.1));
+ probs["ABC"] = std::exp(inv_theta * 1.0);
+ probs["AB C"] = std::exp(inv_theta * (0.2 + 0.1));
+ probs["A BC"] = std::exp(inv_theta * (0.0 + 0.5));
+ probs["A B C"] = std::exp(inv_theta * (0.0 + 0.0 + 0.1));
double Z = 0.0;
for (const auto &it : probs) Z += it.second;
for (const auto &it : probs) {
entropy += (it.second * std::log(it.second));
}
- EXPECT_NEAR(-entropy, lattice.CalculateEntropy(theta), 0.02);
+ EXPECT_NEAR(-entropy, lattice.CalculateEntropy(inv_theta), 0.02);
}
}
InsertWithScore(&lattice, 1, 2, 0.5); // BC
InsertWithScore(&lattice, 0, 3, 1.0); // ABC
- const std::vector<float> kTheta = {0.0, 0.01, 0.5, 0.7, 1.0};
- for (const auto theta : kTheta) {
- std::vector<float> alpha = lattice.ForwardAlgorithm(theta);
+ const std::vector<float> kInv_Theta = {0.0, 0.01, 0.5, 0.7, 1.0};
+ for (const auto inv_theta : kInv_Theta) {
+ std::vector<float> alpha = lattice.ForwardAlgorithm(inv_theta);
EXPECT_EQ(alpha.size(), 8); // 6 nodes, plus BOS, EOS
// only alpha[C], alpha[EOS] have non-zero alpha
for (int i : {0, 1, 2, 3}) {
if (i < 2) {
EXPECT_EQ(alpha[node->node_id], 0.0);
} else if (i == 2) {
- float Z =
- std::log(std::exp(theta * (0.0 + 0.0)) + std::exp(theta * 0.2));
+ float Z = std::log(std::exp(inv_theta * (0.0 + 0.0)) +
+ std::exp(inv_theta * 0.2));
EXPECT_EQ(alpha[node->node_id], Z);
} else if (i == 3) {
- float Z = std::log(std::exp(theta * (0.0 + 0.0 + 0.1)) + // A + B + C
- std::exp(theta * (0.2 + 0.1)) + // AB + C
- std::exp(theta * (0.0 + 0.5)) + // A + BC
- std::exp(theta * 1.0)); // ABC
+ float Z =
+ std::log(std::exp(inv_theta * (0.0 + 0.0 + 0.1)) + // A + B + C
+ std::exp(inv_theta * (0.2 + 0.1)) + // AB + C
+ std::exp(inv_theta * (0.0 + 0.5)) + // A + BC
+ std::exp(inv_theta * 1.0)); // ABC
EXPECT_EQ(Z, alpha[node->node_id]);
}
}
InsertWithScoreAndId(&lattice, 1, 2, 1.7, 4); // BC
InsertWithScoreAndId(&lattice, 0, 3, 1.8, 5); // ABC
- const std::vector<double> kTheta = {0.0, 0.01, 0.5, 0.7, 1.0};
- for (int i = 0; i < kTheta.size(); ++i) {
+ const std::vector<double> kInv_Theta = {0.0, 0.01, 0.5, 0.7, 1.0};
+ for (int i = 0; i < kInv_Theta.size(); ++i) {
std::map<std::string, double> probs;
// Expands all paths in the lattice.
- probs["A B C"] = exp(kTheta[i] * (1.0 + 1.2 + 1.5)); // A B C
- probs["AB C"] = exp(kTheta[i] * (1.6 + 1.5)); // AB C
- probs["A BC"] = exp(kTheta[i] * (1.0 + 1.7)); // A BC
- probs["ABC"] = exp(kTheta[i] * 1.8); // ABC
+ probs["A B C"] = exp(kInv_Theta[i] * (1.0 + 1.2 + 1.5)); // A B C
+ probs["AB C"] = exp(kInv_Theta[i] * (1.6 + 1.5)); // AB C
+ probs["A BC"] = exp(kInv_Theta[i] * (1.0 + 1.7)); // A BC
+ probs["ABC"] = exp(kInv_Theta[i] * 1.8); // ABC
// Computes expected probabilities.
double Z = 0.0;
constexpr int kTrial = 100000;
std::map<std::string, int> freq;
for (int n = 0; n < kTrial; ++n) {
- freq[GetTokenized(lattice.Sample(kTheta[i]))]++;
+ freq[GetTokenized(lattice.Sample(kInv_Theta[i]))]++;
}
EXPECT_EQ(probs.size(), freq.size());
}
// Returns the encoder versions used in parameterized tests.
-const std::vector<EncoderVersion> &GetEncoderVersions() {
- static const std::vector<EncoderVersion> &v =
- *new std::vector<EncoderVersion>{EncoderVersion::kOptimized,
- EncoderVersion::kOriginal};
+const std::vector<Model::EncoderVersion> &GetEncoderVersions() {
+ static const std::vector<Model::EncoderVersion> &v =
+ *new std::vector<Model::EncoderVersion>{Model::kOptimized,
+ Model::kOriginal};
return v;
}
-class UnigramModelTest : public test::TestWithParam<EncoderVersion> {
+class UnigramModelTest : public test::TestWithParam<Model::EncoderVersion> {
protected:
void SetUp() override { encoder_version_ = GetParam(); }
void TearDown() override {}
- EncoderVersion encoder_version_;
+ Model::EncoderVersion encoder_version_;
};
void AddPiece(ModelProto *model_proto, const std::string &piece,
lattice.SetSentence("ABC");
model.PopulateNodes(&lattice);
- std::vector<float> kTheta = {0.0, 1.0};
+ std::vector<float> kInv_Theta = {0.0, 1.0};
- for (const auto theta : kTheta) {
+ for (const auto inv_theta : kInv_Theta) {
std::vector<std::string> strings = {"ABC", "AB C", "A BC", "A B C"};
std::map<std::string, float> probs;
- probs["ABC"] = std::exp(theta * 1.0);
- probs["AB C"] = std::exp(theta * (0.2 + 0.1));
- probs["A BC"] = std::exp(theta * (0.0 + 0.5));
- probs["A B C"] = std::exp(theta * (0.0 + 0.0 + 0.1));
+ probs["ABC"] = std::exp(inv_theta * 1.0);
+ probs["AB C"] = std::exp(inv_theta * (0.2 + 0.1));
+ probs["A BC"] = std::exp(inv_theta * (0.0 + 0.5));
+ probs["A B C"] = std::exp(inv_theta * (0.0 + 0.0 + 0.1));
for (const auto &it : strings) {
EXPECT_EQ(1, probs.count(it));
std::map<std::string, float> scores;
int kTrials = 50000;
for (int i = 0; i < kTrials; i++) {
- NBestEncodeResult sample =
- model.SampleEncodeAndScore("ABC", theta, num_samples, true, false);
+ NBestEncodeResult sample = model.SampleEncodeAndScore(
+ "ABC", inv_theta, num_samples, true, false);
for (const auto &it : sample) {
std::vector<std::string> tokens;
AddPiece(&model_proto, "d", 0.4);
Model model(model_proto);
- EXPECT_TRUE(model.SetEncoderVersion(encoder_version_).ok());
+ model.SetEncoderVersion(encoder_version_);
EXPECT_EQ(model_proto.SerializeAsString(),
model.model_proto().SerializeAsString());
ModelProto model_proto = MakeBaseModelProto();
AddPiece(&model_proto, "x");
Model model(model_proto);
- EXPECT_TRUE(model.SetEncoderVersion(encoder_version_).ok());
+ model.SetEncoderVersion(encoder_version_);
Lattice lattice;
lattice.SetSentence("abc");
AddPiece(&model_proto, "bc", 0.4); // 6
Model model(model_proto);
- EXPECT_TRUE(model.SetEncoderVersion(encoder_version_).ok());
+ model.SetEncoderVersion(encoder_version_);
Lattice lattice;
lattice.SetSentence("abc");
model_proto.mutable_pieces(6)->set_type(ModelProto::SentencePiece::UNUSED);
Model model(model_proto);
- EXPECT_TRUE(model.SetEncoderVersion(encoder_version_).ok());
+ model.SetEncoderVersion(encoder_version_);
Lattice lattice;
lattice.SetSentence("abc");
AddPiece(&model_proto, "abc", 10.0); // 8
Model model(model_proto);
- EXPECT_TRUE(model.SetEncoderVersion(encoder_version_).ok());
+ model.SetEncoderVersion(encoder_version_);
auto nbest = model.NBestEncode("", 10);
EXPECT_EQ(1, nbest.size());
ModelProto::SentencePiece::USER_DEFINED);
Model model(model_proto);
- EXPECT_TRUE(model.SetEncoderVersion(encoder_version_).ok());
+ model.SetEncoderVersion(encoder_version_);
EncodeResult result;
// No unused.
{
Model model(model_proto);
- EXPECT_TRUE(model.SetEncoderVersion(encoder_version_).ok());
+ model.SetEncoderVersion(encoder_version_);
const auto result = model.Encode("abcd");
EXPECT_EQ(1, result.size());
EXPECT_EQ("abcd", result[0].first);
{
model_proto.mutable_pieces(3)->set_type(ModelProto::SentencePiece::UNUSED);
Model model(model_proto);
- EXPECT_TRUE(model.SetEncoderVersion(encoder_version_).ok());
+ model.SetEncoderVersion(encoder_version_);
const auto result = model.Encode("abcd");
EXPECT_EQ(2, result.size());
EXPECT_EQ("abc", result[0].first);
model_proto.mutable_pieces(3)->set_type(ModelProto::SentencePiece::UNUSED);
model_proto.mutable_pieces(5)->set_type(ModelProto::SentencePiece::UNUSED);
Model model(model_proto);
- EXPECT_TRUE(model.SetEncoderVersion(encoder_version_).ok());
+ model.SetEncoderVersion(encoder_version_);
const auto result = model.Encode("abcd");
EXPECT_EQ(2, result.size());
EXPECT_EQ("abc", result[0].first);
model_proto.mutable_pieces(4)->set_type(ModelProto::SentencePiece::UNUSED);
model_proto.mutable_pieces(5)->set_type(ModelProto::SentencePiece::NORMAL);
Model model(model_proto);
- EXPECT_TRUE(model.SetEncoderVersion(encoder_version_).ok());
+ model.SetEncoderVersion(encoder_version_);
const auto result = model.Encode("abcd");
EXPECT_EQ(2, result.size());
EXPECT_EQ("ab", result[0].first);
AddPiece(&model_proto, "c", 2.0); // 9
AddPiece(&model_proto, "d", 1.0); // 10
Model model(model_proto);
- EXPECT_TRUE(model.SetEncoderVersion(encoder_version_).ok());
+ model.SetEncoderVersion(encoder_version_);
// Equivalent outputs.
EXPECT_TRUE(model.VerifyOutputsEquivalent("", ""));
EXPECT_TRUE(model.VerifyOutputsEquivalent("a b", "a b"));
// String utilities
namespace string_util {
-struct string_view_hash {
- // DJB hash function.
- inline size_t operator()(const absl::string_view &sp) const {
- size_t hash = 5381;
- for (size_t i = 0; i < sp.size(); ++i) {
- hash = ((hash << 5) + hash) + sp[i];
- }
- return hash;
- }
-};
-
template <typename Target>
inline bool lexical_cast(absl::string_view arg, Target *result) {
std::stringstream ss;