Fixed test failure.
author Taku Kudo <taku@google.com>
Wed, 3 Aug 2022 08:20:01 +0000 (17:20 +0900)
committer Kentaro Hayashi <kenhys@xdump.org>
Mon, 21 Nov 2022 13:43:46 +0000 (13:43 +0000)
Signed-off-by: Kentaro Hayashi <kenhys@gmail.com>
Gbp-Pq: Name 0016-Fixed-test-failure.patch

python/src/sentencepiece/sentencepiece.i
python/src/sentencepiece/sentencepiece_wrap.cxx
src/sentencepiece_processor.cc
src/sentencepiece_processor.h

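diff --git a/python/src/sentencepiece/sentencepiece.i b/python/src/sentencepiece/sentencepiece.i
index 75f62c80c73bdabc3505a7bae6557bbe04feff0d..1a94fef0393b35edc3c02fd85760d1b8558f7117 100644
--- a/python/src/sentencepiece/sentencepiece.i
+++ b/python/src/sentencepiece/sentencepiece.i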
@@ -193,6 +193,19 @@ inline void CheckIds(const std::vector<int> &ids, int num_pieces) {
 
 inline void CheckIds(const std::vector<absl::string_view> &ids, int num_pieces) {}
 
+template <typename T>
+inline void ConvertToUnicodeSpans(T *proto) {}
+
+template <>
+inline void ConvertToUnicodeSpans(sentencepiece::ImmutableSentencePieceText *proto) {
+  proto->ConvertToUnicodeSpans();
+}
+
+template <>
+inline void ConvertToUnicodeSpans(sentencepiece::ImmutableNBestSentencePieceText *proto) {
+  proto->ConvertToUnicodeSpans();
+}
+
 class ThreadPool {
  public:
   explicit ThreadPool(size_t request_size) :
@@ -239,6 +252,7 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
                        self->FuncName(ins[i]);                          \
             RewriteIds(*self, &out, add_bos, add_eos, reverse,          \
                        emit_unk_piece);                                 \
+            ConvertToUnicodeSpans(&out);                                \
             outs[i] = std::move(out);                                   \
           }                                                             \
         });                                                             \
@@ -255,7 +269,9 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
       pool.Schedule([&, n]() {                                          \
           for (size_t i = n; i < ins.size(); i += num_threads) {        \
             CheckIds(ins[i], self->GetPieceSize());                     \
-            outs[i] = self->FuncName(ins[i]);                           \
+            auto out = self->FuncName(ins[i]);                          \
+            ConvertToUnicodeSpans(&out);                                \
+            outs[i] = std::move(out);                                   \
           }                                                             \
         });                                                             \
     }                                                                   \
@@ -396,6 +412,7 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
     auto proto = enable_sampling ?
                  $self->SampleEncodeAsImmutableProto(text, nbest_size, alpha) :
                  $self->EncodeAsImmutableProto(text);
+    proto.ConvertToUnicodeSpans();
     RewriteIds(*$self, &proto, add_bos, add_eos, reverse, emit_unk_piece);
     return proto;
   }
@@ -467,13 +484,17 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
   sentencepiece::ImmutableSentencePieceText _DecodeIdsAsImmutableProto(
       const std::vector<int> &ids) const {
     CheckIds(ids, $self->GetPieceSize());
-    return $self->DecodeIdsAsImmutableProto(ids);
+    auto proto = $self->DecodeIdsAsImmutableProto(ids);
+    proto.ConvertToUnicodeSpans();
+    return proto;
   }
 
   sentencepiece::ImmutableSentencePieceText _DecodePiecesAsImmutableProto(
       const std::vector<absl::string_view> &pieces) const {
     CheckIds(pieces, $self->GetPieceSize());
-    return $self->DecodePiecesAsImmutableProto(pieces);
+    auto proto= $self->DecodePiecesAsImmutableProto(pieces);
+    proto.ConvertToUnicodeSpans();
+    return proto;
   }
 
   /////////////////////////////////////////////////////////////////////////////
@@ -557,7 +578,9 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
                                    bool emit_unk_piece) const {
     RewriteIds(*$self, static_cast<sentencepiece::ImmutableSentencePieceText *>(nullptr),
                add_bos, add_eos, reverse, emit_unk_piece);
-    return $self->NBestEncodeAsImmutableProto(text, nbest_size);
+    auto proto = $self->NBestEncodeAsImmutableProto(text, nbest_size);
+    proto.ConvertToUnicodeSpans();
+    return proto;
   }
 
 
@@ -611,8 +634,10 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
                                             bool emit_unk_piece) const {
     RewriteIds(*$self, static_cast<sentencepiece::util::bytes *>(nullptr),
                add_bos, add_eos, reverse, emit_unk_piece);
-    return $self->SampleEncodeAndScoreAsImmutableProto(text, num_samples,
+    auto proto = $self->SampleEncodeAndScoreAsImmutableProto(text, num_samples,
                                                        alpha, wor, include_best);
+    proto.ConvertToUnicodeSpans();
+    return proto;
   }
 
 
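diff --git a/python/src/sentencepiece/sentencepiece_wrap.cxx b/python/src/sentencepiece/sentencepiece_wrap.cxx
index 22e0708771f6807e9c69597f31e76098745d5ffe..4b8b5ef122f32aab3cdafcb64481706b980fd2e0 100644
--- a/python/src/sentencepiece/sentencepiece_wrap.cxx
+++ b/python/src/sentencepiece/sentencepiece_wrap.cxx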
@@ -3002,6 +3002,19 @@ inline void CheckIds(const std::vector<int> &ids, int num_pieces) {
 
 inline void CheckIds(const std::vector<absl::string_view> &ids, int num_pieces) {}
 
+template <typename T>
+inline void ConvertToUnicodeSpans(T *proto) {}
+
+template <>
+inline void ConvertToUnicodeSpans(sentencepiece::ImmutableSentencePieceText *proto) {
+  proto->ConvertToUnicodeSpans();
+}
+
+template <>
+inline void ConvertToUnicodeSpans(sentencepiece::ImmutableNBestSentencePieceText *proto) {
+  proto->ConvertToUnicodeSpans();
+}
+
 class ThreadPool {
  public:
   explicit ThreadPool(size_t request_size) :
@@ -3048,6 +3061,7 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
                        self->FuncName(ins[i]);                          \
             RewriteIds(*self, &out, add_bos, add_eos, reverse,          \
                        emit_unk_piece);                                 \
+            ConvertToUnicodeSpans(&out);                                \
             outs[i] = std::move(out);                                   \
           }                                                             \
         });                                                             \
@@ -3064,7 +3078,9 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
       pool.Schedule([&, n]() {                                          \
           for (size_t i = n; i < ins.size(); i += num_threads) {        \
             CheckIds(ins[i], self->GetPieceSize());                     \
-            outs[i] = self->FuncName(ins[i]);                           \
+            auto out = self->FuncName(ins[i]);                          \
+            ConvertToUnicodeSpans(&out);                                \
+            outs[i] = std::move(out);                                   \
           }                                                             \
         });                                                             \
     }                                                                   \
@@ -3540,6 +3556,7 @@ SWIGINTERN sentencepiece::ImmutableSentencePieceText sentencepiece_SentencePiece
     auto proto = enable_sampling ?
                  self->SampleEncodeAsImmutableProto(text, nbest_size, alpha) :
                  self->EncodeAsImmutableProto(text);
+    proto.ConvertToUnicodeSpans();
     RewriteIds(*self, &proto, add_bos, add_eos, reverse, emit_unk_piece);
     return proto;
   }
@@ -3578,11 +3595,15 @@ SWIGINTERN sentencepiece::util::bytes sentencepiece_SentencePieceProcessor__Deco
   }
 SWIGINTERN sentencepiece::ImmutableSentencePieceText sentencepiece_SentencePieceProcessor__DecodeIdsAsImmutableProto(sentencepiece::SentencePieceProcessor const *self,std::vector< int > const &ids){
     CheckIds(ids, self->GetPieceSize());
-    return self->DecodeIdsAsImmutableProto(ids);
+    auto proto = self->DecodeIdsAsImmutableProto(ids);
+    proto.ConvertToUnicodeSpans();
+    return proto;
   }
 SWIGINTERN sentencepiece::ImmutableSentencePieceText sentencepiece_SentencePieceProcessor__DecodePiecesAsImmutableProto(sentencepiece::SentencePieceProcessor const *self,std::vector< absl::string_view > const &pieces){
     CheckIds(pieces, self->GetPieceSize());
-    return self->DecodePiecesAsImmutableProto(pieces);
+    auto proto= self->DecodePiecesAsImmutableProto(pieces);
+    proto.ConvertToUnicodeSpans();
+    return proto;
   }
 SWIGINTERN std::vector< std::string > sentencepiece_SentencePieceProcessor__DecodeIdsBatch(sentencepiece::SentencePieceProcessor const *self,std::vector< std::vector< int > > const &ins,int num_threads){
     DEFINE_DECODE_BATCH_FUNC_IMPL(DecodeIds, int, std::string);
@@ -3628,7 +3649,9 @@ SWIGINTERN sentencepiece::util::bytes sentencepiece_SentencePieceProcessor__NBes
 SWIGINTERN sentencepiece::ImmutableNBestSentencePieceText sentencepiece_SentencePieceProcessor__NBestEncodeAsImmutableProto(sentencepiece::SentencePieceProcessor const *self,absl::string_view text,int nbest_size,bool add_bos,bool add_eos,bool reverse,bool emit_unk_piece){
     RewriteIds(*self, static_cast<sentencepiece::ImmutableSentencePieceText *>(nullptr),
                add_bos, add_eos, reverse, emit_unk_piece);
-    return self->NBestEncodeAsImmutableProto(text, nbest_size);
+    auto proto = self->NBestEncodeAsImmutableProto(text, nbest_size);
+    proto.ConvertToUnicodeSpans();
+    return proto;
   }
 SWIGINTERN std::vector< std::pair< std::vector< int >,float > > sentencepiece_SentencePieceProcessor__SampleEncodeAndScoreAsIds(sentencepiece::SentencePieceProcessor const *self,absl::string_view text,int num_samples,float alpha,bool wor,bool include_best,bool add_bos,bool add_eos,bool reverse,bool emit_unk_piece){
     auto idss = self->SampleEncodeAndScoreAsIds(text, num_samples,
@@ -3655,8 +3678,10 @@ SWIGINTERN sentencepiece::util::bytes sentencepiece_SentencePieceProcessor__Samp
 SWIGINTERN sentencepiece::ImmutableNBestSentencePieceText sentencepiece_SentencePieceProcessor__SampleEncodeAndScoreAsImmutableProto(sentencepiece::SentencePieceProcessor const *self,absl::string_view text,int num_samples,float alpha,bool wor,bool include_best,bool add_bos,bool add_eos,bool reverse,bool emit_unk_piece){
     RewriteIds(*self, static_cast<sentencepiece::util::bytes *>(nullptr),
                add_bos, add_eos, reverse, emit_unk_piece);
-    return self->SampleEncodeAndScoreAsImmutableProto(text, num_samples,
+    auto proto = self->SampleEncodeAndScoreAsImmutableProto(text, num_samples,
                                                        alpha, wor, include_best);
+    proto.ConvertToUnicodeSpans();
+    return proto;
   }
 SWIGINTERN float sentencepiece_SentencePieceProcessor__CalculateEntropy(sentencepiece::SentencePieceProcessor *self,absl::string_view text,float alpha){
     return self->CalculateEntropy(text, alpha);
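diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc
index 2a5c39932ab8698889a1d66ac9b244faf5182017..f0df2f601a2d4d6b36858901475cd810c341ae60 100644
--- a/src/sentencepiece_processor.cc
+++ b/src/sentencepiece_processor.cc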
@@ -56,14 +56,14 @@ std::vector<absl::string_view> ToPieceArray(const std::vector<std::string> &v) {
 }
 
 void ConvertToUnicodeSpansInternal(SentencePieceText *spt) {
-  if (spt == nullptr) return;
+  if (spt == nullptr || spt->text().empty()) return;
 
   std::vector<int> utf8_to_unicode(spt->text().size() + 1, 0);
   absl::string_view str = spt->text();
   size_t prev = 0;
   int ulen = 0;
   while (!str.empty()) {
-    const size_t mblen = string_util::OneCharLen(str.data());
+    const size_t mblen = std::max<int>(1, string_util::OneCharLen(str.data()));
     for (int i = prev; i < prev + mblen; ++i) {
       utf8_to_unicode[i] = ulen;
     }
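diff --git a/src/sentencepiece_processor.h b/src/sentencepiece_processor.h
index be9449e5faa5f227741d91b941db5d13cc43a246..14b1e8cd830c7039af9d5895103dfdabd8cc6ceb 100644
--- a/src/sentencepiece_processor.h
+++ b/src/sentencepiece_processor.h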
@@ -419,47 +419,33 @@ class SentencePieceProcessor {
 
   virtual util::Status Decode(const std::vector<int> &ids,
                               SentencePieceText *spt) const;
-
-#ifndef SWIGPYTHON
-
-#define DEFINE_SPP_DIRECT_FUNC_IMPL(FuncName, OutType, ...) \
-  OutType output;                                           \
-  const auto status = FuncName(__VA_ARGS__, &output);       \
-  return output;
-
-#define DEFINE_SPP_SERIALIZED_PROTO_IMPL(FuncName, OutType, ...)     \
-  OutType output;                                                    \
-  const auto status = FuncName(__VA_ARGS__, output.mutable_proto()); \
-  return output.SerializeAsString();
-
-#define DEFINE_SPP_IMMUTABLE_PROTO_IMPL(FuncName, OutType, ...)      \
-  OutType output;                                                    \
-  const auto status = FuncName(__VA_ARGS__, output.mutable_proto()); \
-  return output;
-
+#ifdef SWIG
+#define SPP_SWIG_CHECK_AND_THROW \
+  if (!status.ok()) throw status;
 #else
+#define SPP_SWIG_CHECK_AND_THROW \
+  if (!status.ok()) {            \
+  }
+#endif  // SWIG
 
 #define DEFINE_SPP_DIRECT_FUNC_IMPL(FuncName, OutType, ...) \
   OutType output;                                           \
   const auto status = FuncName(__VA_ARGS__, &output);       \
-  if (!status.ok()) throw status;                           \
+  SPP_SWIG_CHECK_AND_THROW;                                \
   return output;
 
 #define DEFINE_SPP_SERIALIZED_PROTO_IMPL(FuncName, OutType, ...)     \
   OutType output;                                                    \
   const auto status = FuncName(__VA_ARGS__, output.mutable_proto()); \
-  if (!status.ok()) throw status;                                    \
+  SPP_SWIG_CHECK_AND_THROW;                                         \
   return output.SerializeAsString();
 
 #define DEFINE_SPP_IMMUTABLE_PROTO_IMPL(FuncName, OutType, ...)      \
   OutType output;                                                    \
   const auto status = FuncName(__VA_ARGS__, output.mutable_proto()); \
-  if (!status.ok()) throw status;                                    \
-  output.ConvertToUnicodeSpans();                                    \
+  SPP_SWIG_CHECK_AND_THROW;                                         \
   return output;
 
-#endif  // SWIGPYTHON
-
   //////////////////////////////////////////////////////////////
   // Handy methods that return the result directly.
   // These functions ignore internal errors.