Adds more unittests
author Taku Kudo <taku@google.com>
Tue, 2 Aug 2022 17:24:53 +0000 (02:24 +0900)
committer Kentaro Hayashi <kenhys@xdump.org>
Mon, 21 Nov 2022 13:43:46 +0000 (13:43 +0000)
Signed-off-by: Kentaro Hayashi <kenhys@gmail.com>
Gbp-Pq: Name 0013-Adds-more-unittests.patch

python/src/sentencepiece/__init__.py
python/src/sentencepiece/sentencepiece.i
python/src/sentencepiece/sentencepiece_wrap.cxx
python/test/sentencepiece_test.py
src/sentencepiece_processor.cc
src/sentencepiece_processor.h
src/sentencepiece_processor_test.cc

index 69a9825821e894ac4834984ea6370ad3233a3aef..07acb940688fda2531ebc1a6b6b68194c8976edc 100644 (file)
@@ -98,6 +98,9 @@ class ImmutableSentencePieceText(object):
     def pieces_size(self):
         return _sentencepiece.ImmutableSentencePieceText_pieces_size(self)
 
+    def pieces(self, index):
+        return _sentencepiece.ImmutableSentencePieceText_pieces(self, index)
+
     def text(self):
         return _sentencepiece.ImmutableSentencePieceText_text(self)
 
@@ -107,18 +110,24 @@ class ImmutableSentencePieceText(object):
     def SerializeAsString(self):
         return _sentencepiece.ImmutableSentencePieceText_SerializeAsString(self)
 
-    def pieces(self, index):
-        return _sentencepiece.ImmutableSentencePieceText_pieces(self, index)
+    def _pieces(self, index):
+        return _sentencepiece.ImmutableSentencePieceText__pieces(self, index)
+
+    def pieces(self, i):
+      return self._pieces(i)
 
     def __len__(self):
       return self.pieces_size()
 
     def __getitem__(self, i):
-      return self.pieces(i)
+      return self._pieces(i)
 
     def __eq__(self, other):
       return self.SerializeAsString() == other.SerializeAsString()
 
+    def __hash__(self):
+      return hash(self.SerializeAsString())        
+
 
 # Register ImmutableSentencePieceText in _sentencepiece:
 _sentencepiece.ImmutableSentencePieceText_swigregister(ImmutableSentencePieceText)
@@ -134,21 +143,30 @@ class ImmutableNBestSentencePieceText(object):
     def nbests_size(self):
         return _sentencepiece.ImmutableNBestSentencePieceText_nbests_size(self)
 
+    def nbests(self, index):
+        return _sentencepiece.ImmutableNBestSentencePieceText_nbests(self, index)
+
     def SerializeAsString(self):
         return _sentencepiece.ImmutableNBestSentencePieceText_SerializeAsString(self)
 
-    def nbests(self, index):
-        return _sentencepiece.ImmutableNBestSentencePieceText_nbests(self, index)
+    def _nbests(self, index):
+        return _sentencepiece.ImmutableNBestSentencePieceText__nbests(self, index)
+
+    def __nbests__(self, i):
+      return self._nbests(i)
 
     def __len__(self):
       return self.nbests_size()
 
     def __getitem__(self, i):
-      return self.nbests(i)
+      return self._nbests(i)
 
     def __eq__(self, other):
       return self.SerializeAsString() == other.SerializeAsString()
 
+    def __hash__(self):
+      return hash(self.SerializeAsString())
+
 
 # Register ImmutableNBestSentencePieceText in _sentencepiece:
 _sentencepiece.ImmutableNBestSentencePieceText_swigregister(ImmutableNBestSentencePieceText)
@@ -272,6 +290,9 @@ class SentencePieceProcessor(object):
     def _DecodeIdsAsSerializedProtoBatch(self, ins, num_threads):
         return _sentencepiece.SentencePieceProcessor__DecodeIdsAsSerializedProtoBatch(self, ins, num_threads)
 
+    def _DecodeIdsAsImmutableProtoBatch(self, ins, num_threads):
+        return _sentencepiece.SentencePieceProcessor__DecodeIdsAsImmutableProtoBatch(self, ins, num_threads)
+
     def _DecodePiecesBatch(self, ins, num_threads):
         return _sentencepiece.SentencePieceProcessor__DecodePiecesBatch(self, ins, num_threads)
 
@@ -539,6 +560,8 @@ class SentencePieceProcessor(object):
           return self._NBestEncodeAsImmutableProto(text, nbest_size,
                                                    add_bos, add_eos, reverse, emit_unk_piece)
 
+        raise RuntimeError('unknown out_type')
+
       if type(input) is list:
         return [_encode(n) for n in input]
 
@@ -621,10 +644,21 @@ class SentencePieceProcessor(object):
         if out_type is int:
           return self._SampleEncodeAndScoreAsIds(text, num_samples, alpha, wor, include_best,
                                                  add_bos, add_eos, reverse, emit_unk_piece)
-        else:
+        if out_type is str:
           return self._SampleEncodeAndScoreAsPieces(text, num_samples, alpha, wor, include_best,
                                                     add_bos, add_eos, reverse, emit_unk_piece)
 
+        if out_type == 'serialized_proto' or out_type == 'proto':
+          return self._SampleEncodeAndScoreAsSerializedProto(text, num_samples, alpha, wor, include_best,
+                                                             add_bos, add_eos, reverse, emit_unk_piece)
+
+        if out_type == 'immutable_proto':
+          return self._SampleEncodeAndScoreAsImmutableProto(text, num_samples, alpha, wor, include_best,
+                                                            add_bos, add_eos, reverse, emit_unk_piece)
+
+        raise RuntimeError('unknown output type')
+
+
       if type(input) is list:
         return [_encode(n) for n in input]
 
index 1e2e1e0880489f20327cfd3c051c9da3a3ba5d29..f3a4f3044cc5a9017053dcc503b7c54f54df33e3 100644 (file)
@@ -2,6 +2,7 @@
 %include exception.i
 
 %{
+
 #include <iostream>
 #include <algorithm>
 #include <functional>
@@ -286,8 +287,10 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
 %ignore sentencepiece::SentencePieceProcessor::status;
 %ignore sentencepiece::ImmutableSentencePieceText::mutable_proto;
 %ignore sentencepiece::ImmutableSentencePieceText::pieces() const;
+%ignore sentencepiece::ImmutableSentencePieceText::ConvertToUnicodeSpans;
 %ignore sentencepiece::ImmutableNBestSentencePieceText::mutable_proto;
 %ignore sentencepiece::ImmutableNBestSentencePieceText::nbests() const;
+%ignore sentencepiece::ImmutableNBestSentencePieceText::ConvertToUnicodeSpans;
 
 %ignore sentencepiece::SentencePieceProcessor::Encode;
 %ignore sentencepiece::SentencePieceProcessor::SampleEncode;
@@ -481,6 +484,13 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
                                   sentencepiece::util::bytes);
   }
 
+  std::vector<sentencepiece::ImmutableSentencePieceText>
+      _DecodeIdsAsImmutableProtoBatch(
+          const std::vector<std::vector<int>> &ins, int num_threads) const {
+    DEFINE_DECODE_BATCH_FUNC_IMPL(DecodeIdsAsImmutableProto, int,
+                                  sentencepiece::ImmutableSentencePieceText);
+  }
+
   std::vector<std::string> _DecodePiecesBatch(
       const std::vector<std::vector<absl::string_view>> &ins, int num_threads) const {
     DEFINE_DECODE_BATCH_FUNC_IMPL(DecodePieces, std::string, std::string);
@@ -852,6 +862,8 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
         return self._NBestEncodeAsImmutableProto(text, nbest_size,
                                                  add_bos, add_eos, reverse, emit_unk_piece)
 
+      raise RuntimeError('unknown out_type')
+
     if type(input) is list:
       return [_encode(n) for n in input]
 
@@ -934,10 +946,21 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
       if out_type is int:
         return self._SampleEncodeAndScoreAsIds(text, num_samples, alpha, wor, include_best,
                                                add_bos, add_eos, reverse, emit_unk_piece)
-      else:
+      if out_type is str:
         return self._SampleEncodeAndScoreAsPieces(text, num_samples, alpha, wor, include_best,
                                                   add_bos, add_eos, reverse, emit_unk_piece)
 
+      if out_type == 'serialized_proto' or out_type == 'proto':
+        return self._SampleEncodeAndScoreAsSerializedProto(text, num_samples, alpha, wor, include_best,
+                                                           add_bos, add_eos, reverse, emit_unk_piece)
+
+      if out_type == 'immutable_proto':
+        return self._SampleEncodeAndScoreAsImmutableProto(text, num_samples, alpha, wor, include_best,
+                                                          add_bos, add_eos, reverse, emit_unk_piece)
+
+      raise RuntimeError('unknown output type')
+
+
     if type(input) is list:
       return [_encode(n) for n in input]
 
@@ -1187,7 +1210,7 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
 }
 
 %extend sentencepiece::ImmutableSentencePieceText {
-  ImmutableSentencePieceText_ImmutableSentencePiece pieces(int index) const {
+  ImmutableSentencePieceText_ImmutableSentencePiece _pieces(int index) const {
     if (index < 0 || index >= static_cast<int>($self->pieces_size())) {
       throw sentencepiece::util::Status(
           sentencepiece::util::StatusCode::kOutOfRange,
@@ -1197,19 +1220,25 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
   }
 
 %pythoncode {
+  def pieces(self, i):
+    return self._pieces(i)
+
   def __len__(self):
     return self.pieces_size()
 
   def __getitem__(self, i):
-    return self.pieces(i)
+    return self._pieces(i)
 
   def __eq__(self, other):
     return self.SerializeAsString() == other.SerializeAsString()
+
+  def __hash__(self):
+    return hash(self.SerializeAsString())
 }
 }
 
 %extend sentencepiece::ImmutableNBestSentencePieceText {
-  ImmutableSentencePieceText nbests(int index) const {
+  ImmutableSentencePieceText _nbests(int index) const {
     if (index < 0 || index >= static_cast<int>($self->nbests_size())) {
       throw sentencepiece::util::Status(
           sentencepiece::util::StatusCode::kOutOfRange,
@@ -1219,14 +1248,20 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
   }
 
 %pythoncode {
+  def __nbests__(self, i):
+    return self._nbests(i)
+
   def __len__(self):
     return self.nbests_size()
 
   def __getitem__(self, i):
-    return self.nbests(i)
+    return self._nbests(i)
 
   def __eq__(self, other):
     return self.SerializeAsString() == other.SerializeAsString()
+
+  def __hash__(self):
+    return hash(self.SerializeAsString())
 }
 }
 
index 9776b0f562a41c50b138dc908e1e6cd8f9b787f7..22e0708771f6807e9c69597f31e76098745d5ffe 100644 (file)
@@ -2811,6 +2811,7 @@ namespace swig {
 }
 
 
+
 #include <iostream>
 #include <algorithm>
 #include <functional>
@@ -3132,16 +3133,6 @@ SWIG_From_size_t  (size_t value)
 }
 
 
-  #define SWIG_From_double   PyFloat_FromDouble 
-
-
-SWIGINTERNINLINE PyObject *
-SWIG_From_float  (float value)
-{    
-  return SWIG_From_double  (value);
-}
-
-
 SWIGINTERN int
 SWIG_AsVal_double (PyObject *obj, double *val)
 {
@@ -3282,7 +3273,17 @@ SWIG_AsVal_int (PyObject * obj, int *val)
   return res;
 }
 
-SWIGINTERN sentencepiece::ImmutableSentencePieceText_ImmutableSentencePiece sentencepiece_ImmutableSentencePieceText_pieces(sentencepiece::ImmutableSentencePieceText const *self,int index){
+
+  #define SWIG_From_double   PyFloat_FromDouble 
+
+
+SWIGINTERNINLINE PyObject *
+SWIG_From_float  (float value)
+{    
+  return SWIG_From_double  (value);
+}
+
+SWIGINTERN sentencepiece::ImmutableSentencePieceText_ImmutableSentencePiece sentencepiece_ImmutableSentencePieceText__pieces(sentencepiece::ImmutableSentencePieceText const *self,int index){
     if (index < 0 || index >= static_cast<int>(self->pieces_size())) {
       throw sentencepiece::util::Status(
           sentencepiece::util::StatusCode::kOutOfRange,
@@ -3290,7 +3291,7 @@ SWIGINTERN sentencepiece::ImmutableSentencePieceText_ImmutableSentencePiece sent
     }
     return self->pieces(index);
   }
-SWIGINTERN sentencepiece::ImmutableSentencePieceText sentencepiece_ImmutableNBestSentencePieceText_nbests(sentencepiece::ImmutableNBestSentencePieceText const *self,int index){
+SWIGINTERN sentencepiece::ImmutableSentencePieceText sentencepiece_ImmutableNBestSentencePieceText__nbests(sentencepiece::ImmutableNBestSentencePieceText const *self,int index){
     if (index < 0 || index >= static_cast<int>(self->nbests_size())) {
       throw sentencepiece::util::Status(
           sentencepiece::util::StatusCode::kOutOfRange,
@@ -3590,6 +3591,10 @@ SWIGINTERN BytesArray sentencepiece_SentencePieceProcessor__DecodeIdsAsSerialize
     DEFINE_DECODE_BATCH_FUNC_IMPL(DecodeIdsAsSerializedProto, int,
                                   sentencepiece::util::bytes);
   }
+SWIGINTERN std::vector< sentencepiece::ImmutableSentencePieceText > sentencepiece_SentencePieceProcessor__DecodeIdsAsImmutableProtoBatch(sentencepiece::SentencePieceProcessor const *self,std::vector< std::vector< int > > const &ins,int num_threads){
+    DEFINE_DECODE_BATCH_FUNC_IMPL(DecodeIdsAsImmutableProto, int,
+                                  sentencepiece::ImmutableSentencePieceText);
+  }
 SWIGINTERN std::vector< std::string > sentencepiece_SentencePieceProcessor__DecodePiecesBatch(sentencepiece::SentencePieceProcessor const *self,std::vector< std::vector< absl::string_view > > const &ins,int num_threads){
     DEFINE_DECODE_BATCH_FUNC_IMPL(DecodePieces, std::string, std::string);
   }
@@ -4070,6 +4075,44 @@ fail:
 }
 
 
+SWIGINTERN PyObject *_wrap_ImmutableSentencePieceText_pieces(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
+  PyObject *resultobj = 0;
+  sentencepiece::ImmutableSentencePieceText *arg1 = (sentencepiece::ImmutableSentencePieceText *) 0 ;
+  int arg2 ;
+  void *argp1 = 0 ;
+  int res1 = 0 ;
+  int val2 ;
+  int ecode2 = 0 ;
+  PyObject *swig_obj[2] ;
+  sentencepiece::ImmutableSentencePieceText_ImmutableSentencePiece result;
+  
+  if (!SWIG_Python_UnpackTuple(args, "ImmutableSentencePieceText_pieces", 2, 2, swig_obj)) SWIG_fail;
+  res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__ImmutableSentencePieceText, 0 |  0 );
+  if (!SWIG_IsOK(res1)) {
+    SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "ImmutableSentencePieceText_pieces" "', argument " "1"" of type '" "sentencepiece::ImmutableSentencePieceText const *""'"); 
+  }
+  arg1 = reinterpret_cast< sentencepiece::ImmutableSentencePieceText * >(argp1);
+  ecode2 = SWIG_AsVal_int(swig_obj[1], &val2);
+  if (!SWIG_IsOK(ecode2)) {
+    SWIG_exception_fail(SWIG_ArgError(ecode2), "in method '" "ImmutableSentencePieceText_pieces" "', argument " "2"" of type '" "int""'");
+  } 
+  arg2 = static_cast< int >(val2);
+  {
+    try {
+      result = ((sentencepiece::ImmutableSentencePieceText const *)arg1)->pieces(arg2);
+      ReleaseResultObject(resultobj);
+    }
+    catch (const sentencepiece::util::Status &status) {
+      SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
+    }
+  }
+  resultobj = SWIG_NewPointerObj((new sentencepiece::ImmutableSentencePieceText_ImmutableSentencePiece(static_cast< const sentencepiece::ImmutableSentencePieceText_ImmutableSentencePiece& >(result))), SWIGTYPE_p_sentencepiece__ImmutableSentencePieceText_ImmutableSentencePiece, SWIG_POINTER_OWN |  0 );
+  return resultobj;
+fail:
+  return NULL;
+}
+
+
 SWIGINTERN PyObject *_wrap_ImmutableSentencePieceText_text(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
   PyObject *resultobj = 0;
   sentencepiece::ImmutableSentencePieceText *arg1 = (sentencepiece::ImmutableSentencePieceText *) 0 ;
@@ -4168,7 +4211,7 @@ fail:
 }
 
 
-SWIGINTERN PyObject *_wrap_ImmutableSentencePieceText_pieces(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
+SWIGINTERN PyObject *_wrap_ImmutableSentencePieceText__pieces(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
   PyObject *resultobj = 0;
   sentencepiece::ImmutableSentencePieceText *arg1 = (sentencepiece::ImmutableSentencePieceText *) 0 ;
   int arg2 ;
@@ -4179,20 +4222,20 @@ SWIGINTERN PyObject *_wrap_ImmutableSentencePieceText_pieces(PyObject *SWIGUNUSE
   PyObject *swig_obj[2] ;
   sentencepiece::ImmutableSentencePieceText_ImmutableSentencePiece result;
   
-  if (!SWIG_Python_UnpackTuple(args, "ImmutableSentencePieceText_pieces", 2, 2, swig_obj)) SWIG_fail;
+  if (!SWIG_Python_UnpackTuple(args, "ImmutableSentencePieceText__pieces", 2, 2, swig_obj)) SWIG_fail;
   res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__ImmutableSentencePieceText, 0 |  0 );
   if (!SWIG_IsOK(res1)) {
-    SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "ImmutableSentencePieceText_pieces" "', argument " "1"" of type '" "sentencepiece::ImmutableSentencePieceText const *""'"); 
+    SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "ImmutableSentencePieceText__pieces" "', argument " "1"" of type '" "sentencepiece::ImmutableSentencePieceText const *""'"); 
   }
   arg1 = reinterpret_cast< sentencepiece::ImmutableSentencePieceText * >(argp1);
   ecode2 = SWIG_AsVal_int(swig_obj[1], &val2);
   if (!SWIG_IsOK(ecode2)) {
-    SWIG_exception_fail(SWIG_ArgError(ecode2), "in method '" "ImmutableSentencePieceText_pieces" "', argument " "2"" of type '" "int""'");
+    SWIG_exception_fail(SWIG_ArgError(ecode2), "in method '" "ImmutableSentencePieceText__pieces" "', argument " "2"" of type '" "int""'");
   } 
   arg2 = static_cast< int >(val2);
   {
     try {
-      result = sentencepiece_ImmutableSentencePieceText_pieces((sentencepiece::ImmutableSentencePieceText const *)arg1,arg2);
+      result = sentencepiece_ImmutableSentencePieceText__pieces((sentencepiece::ImmutableSentencePieceText const *)arg1,arg2);
       ReleaseResultObject(resultobj);
     }
     catch (const sentencepiece::util::Status &status) {
@@ -4299,6 +4342,44 @@ fail:
 }
 
 
+SWIGINTERN PyObject *_wrap_ImmutableNBestSentencePieceText_nbests(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
+  PyObject *resultobj = 0;
+  sentencepiece::ImmutableNBestSentencePieceText *arg1 = (sentencepiece::ImmutableNBestSentencePieceText *) 0 ;
+  int arg2 ;
+  void *argp1 = 0 ;
+  int res1 = 0 ;
+  int val2 ;
+  int ecode2 = 0 ;
+  PyObject *swig_obj[2] ;
+  sentencepiece::ImmutableSentencePieceText result;
+  
+  if (!SWIG_Python_UnpackTuple(args, "ImmutableNBestSentencePieceText_nbests", 2, 2, swig_obj)) SWIG_fail;
+  res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__ImmutableNBestSentencePieceText, 0 |  0 );
+  if (!SWIG_IsOK(res1)) {
+    SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "ImmutableNBestSentencePieceText_nbests" "', argument " "1"" of type '" "sentencepiece::ImmutableNBestSentencePieceText const *""'"); 
+  }
+  arg1 = reinterpret_cast< sentencepiece::ImmutableNBestSentencePieceText * >(argp1);
+  ecode2 = SWIG_AsVal_int(swig_obj[1], &val2);
+  if (!SWIG_IsOK(ecode2)) {
+    SWIG_exception_fail(SWIG_ArgError(ecode2), "in method '" "ImmutableNBestSentencePieceText_nbests" "', argument " "2"" of type '" "int""'");
+  } 
+  arg2 = static_cast< int >(val2);
+  {
+    try {
+      result = ((sentencepiece::ImmutableNBestSentencePieceText const *)arg1)->nbests(arg2);
+      ReleaseResultObject(resultobj);
+    }
+    catch (const sentencepiece::util::Status &status) {
+      SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
+    }
+  }
+  resultobj = SWIG_NewPointerObj((new sentencepiece::ImmutableSentencePieceText(static_cast< const sentencepiece::ImmutableSentencePieceText& >(result))), SWIGTYPE_p_sentencepiece__ImmutableSentencePieceText, SWIG_POINTER_OWN |  0 );
+  return resultobj;
+fail:
+  return NULL;
+}
+
+
 SWIGINTERN PyObject *_wrap_ImmutableNBestSentencePieceText_SerializeAsString(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
   PyObject *resultobj = 0;
   sentencepiece::ImmutableNBestSentencePieceText *arg1 = (sentencepiece::ImmutableNBestSentencePieceText *) 0 ;
@@ -4332,7 +4413,7 @@ fail:
 }
 
 
-SWIGINTERN PyObject *_wrap_ImmutableNBestSentencePieceText_nbests(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
+SWIGINTERN PyObject *_wrap_ImmutableNBestSentencePieceText__nbests(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
   PyObject *resultobj = 0;
   sentencepiece::ImmutableNBestSentencePieceText *arg1 = (sentencepiece::ImmutableNBestSentencePieceText *) 0 ;
   int arg2 ;
@@ -4343,20 +4424,20 @@ SWIGINTERN PyObject *_wrap_ImmutableNBestSentencePieceText_nbests(PyObject *SWIG
   PyObject *swig_obj[2] ;
   sentencepiece::ImmutableSentencePieceText result;
   
-  if (!SWIG_Python_UnpackTuple(args, "ImmutableNBestSentencePieceText_nbests", 2, 2, swig_obj)) SWIG_fail;
+  if (!SWIG_Python_UnpackTuple(args, "ImmutableNBestSentencePieceText__nbests", 2, 2, swig_obj)) SWIG_fail;
   res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__ImmutableNBestSentencePieceText, 0 |  0 );
   if (!SWIG_IsOK(res1)) {
-    SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "ImmutableNBestSentencePieceText_nbests" "', argument " "1"" of type '" "sentencepiece::ImmutableNBestSentencePieceText const *""'"); 
+    SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "ImmutableNBestSentencePieceText__nbests" "', argument " "1"" of type '" "sentencepiece::ImmutableNBestSentencePieceText const *""'"); 
   }
   arg1 = reinterpret_cast< sentencepiece::ImmutableNBestSentencePieceText * >(argp1);
   ecode2 = SWIG_AsVal_int(swig_obj[1], &val2);
   if (!SWIG_IsOK(ecode2)) {
-    SWIG_exception_fail(SWIG_ArgError(ecode2), "in method '" "ImmutableNBestSentencePieceText_nbests" "', argument " "2"" of type '" "int""'");
+    SWIG_exception_fail(SWIG_ArgError(ecode2), "in method '" "ImmutableNBestSentencePieceText__nbests" "', argument " "2"" of type '" "int""'");
   } 
   arg2 = static_cast< int >(val2);
   {
     try {
-      result = sentencepiece_ImmutableNBestSentencePieceText_nbests((sentencepiece::ImmutableNBestSentencePieceText const *)arg1,arg2);
+      result = sentencepiece_ImmutableNBestSentencePieceText__nbests((sentencepiece::ImmutableNBestSentencePieceText const *)arg1,arg2);
       ReleaseResultObject(resultobj);
     }
     catch (const sentencepiece::util::Status &status) {
@@ -6822,6 +6903,87 @@ fail:
 }
 
 
+SWIGINTERN PyObject *_wrap_SentencePieceProcessor__DecodeIdsAsImmutableProtoBatch(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
+  PyObject *resultobj = 0;
+  sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ;
+  std::vector< std::vector< int > > *arg2 = 0 ;
+  int arg3 ;
+  void *argp1 = 0 ;
+  int res1 = 0 ;
+  int val3 ;
+  int ecode3 = 0 ;
+  PyObject *swig_obj[3] ;
+  SwigValueWrapper< std::vector< sentencepiece::ImmutableSentencePieceText > > result;
+  
+  if (!SWIG_Python_UnpackTuple(args, "SentencePieceProcessor__DecodeIdsAsImmutableProtoBatch", 3, 3, swig_obj)) SWIG_fail;
+  res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 |  0 );
+  if (!SWIG_IsOK(res1)) {
+    SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor__DecodeIdsAsImmutableProtoBatch" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor const *""'"); 
+  }
+  arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1);
+  {
+    std::vector<std::vector<int>> *out = nullptr;
+    if (PyList_Check(swig_obj[1])) {
+      const size_t size = PyList_Size(swig_obj[1]);
+      out = new std::vector<std::vector<int>>(size);
+      for (size_t i = 0; i < size; ++i) {
+        PyObject *o = PyList_GetItem(swig_obj[1], i);
+        if (PyList_Check(o)) {
+          const size_t size2 = PyList_Size(o);
+          (*out)[i].resize(size2);
+          for (size_t j = 0; j < size2; ++j) {
+            PyObject *o2 = PyList_GetItem(o, j);
+            if (PyInt_Check(o2)) {
+              (*out)[i][j] = static_cast<int>(PyInt_AsLong(o2));
+            } else {
+              PyErr_SetString(PyExc_TypeError, "list must contain strings");
+              SWIG_fail;
+            }
+          }
+        } else {
+          PyErr_SetString(PyExc_TypeError, "not a list");
+          SWIG_fail;
+        }
+      }
+    } else {
+      PyErr_SetString(PyExc_TypeError,"not a list");
+      SWIG_fail;
+    }
+    arg2 = out;
+  }
+  ecode3 = SWIG_AsVal_int(swig_obj[2], &val3);
+  if (!SWIG_IsOK(ecode3)) {
+    SWIG_exception_fail(SWIG_ArgError(ecode3), "in method '" "SentencePieceProcessor__DecodeIdsAsImmutableProtoBatch" "', argument " "3"" of type '" "int""'");
+  } 
+  arg3 = static_cast< int >(val3);
+  {
+    try {
+      result = sentencepiece_SentencePieceProcessor__DecodeIdsAsImmutableProtoBatch((sentencepiece::SentencePieceProcessor const *)arg1,(std::vector< std::vector< int > > const &)*arg2,arg3);
+      ReleaseResultObject(resultobj);
+    }
+    catch (const sentencepiece::util::Status &status) {
+      SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
+    }
+  }
+  {
+    resultobj = PyList_New((&result)->size());
+    for (size_t i = 0; i < (&result)->size(); ++i) {
+      PyObject *obj = SWIG_NewPointerObj(new sentencepiece::ImmutableSentencePieceText((&result)->at(i)), SWIGTYPE_p_sentencepiece__ImmutableSentencePieceText, SWIG_POINTER_OWN | 0);
+      PyList_SET_ITEM(resultobj, i, obj);
+    }
+  }
+  {
+    delete arg2;
+  }
+  return resultobj;
+fail:
+  {
+    delete arg2;
+  }
+  return NULL;
+}
+
+
 SWIGINTERN PyObject *_wrap_SentencePieceProcessor__DecodePiecesBatch(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
   PyObject *resultobj = 0;
   sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ;
@@ -8298,17 +8460,19 @@ static PyMethodDef SwigMethods[] = {
         { "new_ImmutableSentencePieceText", _wrap_new_ImmutableSentencePieceText, METH_NOARGS, NULL},
         { "delete_ImmutableSentencePieceText", _wrap_delete_ImmutableSentencePieceText, METH_O, NULL},
         { "ImmutableSentencePieceText_pieces_size", _wrap_ImmutableSentencePieceText_pieces_size, METH_O, NULL},
+        { "ImmutableSentencePieceText_pieces", _wrap_ImmutableSentencePieceText_pieces, METH_VARARGS, NULL},
         { "ImmutableSentencePieceText_text", _wrap_ImmutableSentencePieceText_text, METH_O, NULL},
         { "ImmutableSentencePieceText_score", _wrap_ImmutableSentencePieceText_score, METH_O, NULL},
         { "ImmutableSentencePieceText_SerializeAsString", _wrap_ImmutableSentencePieceText_SerializeAsString, METH_O, NULL},
-        { "ImmutableSentencePieceText_pieces", _wrap_ImmutableSentencePieceText_pieces, METH_VARARGS, NULL},
+        { "ImmutableSentencePieceText__pieces", _wrap_ImmutableSentencePieceText__pieces, METH_VARARGS, NULL},
         { "ImmutableSentencePieceText_swigregister", ImmutableSentencePieceText_swigregister, METH_O, NULL},
         { "ImmutableSentencePieceText_swiginit", ImmutableSentencePieceText_swiginit, METH_VARARGS, NULL},
         { "new_ImmutableNBestSentencePieceText", _wrap_new_ImmutableNBestSentencePieceText, METH_NOARGS, NULL},
         { "delete_ImmutableNBestSentencePieceText", _wrap_delete_ImmutableNBestSentencePieceText, METH_O, NULL},
         { "ImmutableNBestSentencePieceText_nbests_size", _wrap_ImmutableNBestSentencePieceText_nbests_size, METH_O, NULL},
-        { "ImmutableNBestSentencePieceText_SerializeAsString", _wrap_ImmutableNBestSentencePieceText_SerializeAsString, METH_O, NULL},
         { "ImmutableNBestSentencePieceText_nbests", _wrap_ImmutableNBestSentencePieceText_nbests, METH_VARARGS, NULL},
+        { "ImmutableNBestSentencePieceText_SerializeAsString", _wrap_ImmutableNBestSentencePieceText_SerializeAsString, METH_O, NULL},
+        { "ImmutableNBestSentencePieceText__nbests", _wrap_ImmutableNBestSentencePieceText__nbests, METH_VARARGS, NULL},
         { "ImmutableNBestSentencePieceText_swigregister", ImmutableNBestSentencePieceText_swigregister, METH_O, NULL},
         { "ImmutableNBestSentencePieceText_swiginit", ImmutableNBestSentencePieceText_swiginit, METH_VARARGS, NULL},
         { "new_SentencePieceProcessor", _wrap_new_SentencePieceProcessor, METH_NOARGS, NULL},
@@ -8350,6 +8514,7 @@ static PyMethodDef SwigMethods[] = {
         { "SentencePieceProcessor__DecodePiecesAsImmutableProto", _wrap_SentencePieceProcessor__DecodePiecesAsImmutableProto, METH_VARARGS, NULL},
         { "SentencePieceProcessor__DecodeIdsBatch", _wrap_SentencePieceProcessor__DecodeIdsBatch, METH_VARARGS, NULL},
         { "SentencePieceProcessor__DecodeIdsAsSerializedProtoBatch", _wrap_SentencePieceProcessor__DecodeIdsAsSerializedProtoBatch, METH_VARARGS, NULL},
+        { "SentencePieceProcessor__DecodeIdsAsImmutableProtoBatch", _wrap_SentencePieceProcessor__DecodeIdsAsImmutableProtoBatch, METH_VARARGS, NULL},
         { "SentencePieceProcessor__DecodePiecesBatch", _wrap_SentencePieceProcessor__DecodePiecesBatch, METH_VARARGS, NULL},
         { "SentencePieceProcessor__DecodePiecesAsSerializedProtoBatch", _wrap_SentencePieceProcessor__DecodePiecesAsSerializedProtoBatch, METH_VARARGS, NULL},
         { "SentencePieceProcessor__DecodePiecesAsImmutableProtoBatch", _wrap_SentencePieceProcessor__DecodePiecesAsImmutableProtoBatch, METH_VARARGS, NULL},
index 2f2c84aec542abfcb0aaf6469819b531486fa1fa..5e4af7f4d7264ab12363eed41451cb571c5f2ef5 100755 (executable)
@@ -266,6 +266,13 @@ class TestSentencepieceProcessor(unittest.TestCase):
     t4 = self.sp_.decode_pieces_as_serialized_proto(['foo', 'bar'])
     t5 = self.sp_.decode_ids_as_serialized_proto([20, 30])
 
+    y1 = self.sp_.encode(text, out_type='serialized_proto')
+    y2 = self.sp_.encode(
+        text, enable_sampling=True, out_type='serialized_proto')
+    y3 = self.sp_.nbest_encode(text, out_type='serialized_proto', nbest_size=10)
+    y4 = self.sp_.decode(['foo', 'bar'], out_type='serialized_proto')
+    y5 = self.sp_.decode([20, 30], out_type='serialized_proto')
+
     self.assertEqual(type(s1), bytes)
     self.assertEqual(type(s2), bytes)
     self.assertEqual(type(t2), bytes)
@@ -277,6 +284,92 @@ class TestSentencepieceProcessor(unittest.TestCase):
     self.assertEqual(s3, t3)
     self.assertEqual(s4, t4)
     self.assertEqual(s5, t5)
+    self.assertEqual(s1, y1)
+    self.assertEqual(s3, y3)
+    self.assertEqual(s4, y4)
+    self.assertEqual(s5, y5)
+
+    ids = self.jasp_.EncodeAsIds(text)
+    pieces = self.jasp_.EncodeAsPieces(text)
+    s1 = self.jasp_.EncodeAsSerializedProto(text)
+    s2 = self.jasp_.DecodeIdsAsSerializedProto(ids)
+    s3 = self.jasp_.DecodePiecesAsSerializedProto(ids)
+    self.assertEqual(s2, s1)
+    self.assertEqual(s3, s1)
+
+  def test_immutable_proto(self):
+    text = 'I saw a girl with a telescope.'
+    s1 = self.sp_.EncodeAsImmutableProto(text)
+    s2 = self.sp_.SampleEncodeAsImmutableProto(text, 10, 0.2)
+    s3 = self.sp_.NBestEncodeAsImmutableProto(text, 10)
+    s4 = self.sp_.DecodePiecesAsImmutableProto(['foo', 'bar'])
+    s5 = self.sp_.DecodeIdsAsImmutableProto([20, 30])
+
+    t1 = self.sp_.encode_as_immutable_proto(text)
+    t2 = self.sp_.sample_encode_as_immutable_proto(text, 10, 0.2)
+    t3 = self.sp_.nbest_encode_as_immutable_proto(text, 10)
+    t4 = self.sp_.decode_pieces_as_immutable_proto(['foo', 'bar'])
+    t5 = self.sp_.decode_ids_as_immutable_proto([20, 30])
+
+    y1 = self.sp_.encode(text, out_type='immutable_proto')
+    y2 = self.sp_.encode(text, enable_sampling=True, out_type='immutable_proto')
+    y3 = self.sp_.nbest_encode(text, out_type='immutable_proto', nbest_size=10)
+    y4 = self.sp_.decode(['foo', 'bar'], out_type='immutable_proto')
+    y5 = self.sp_.decode([20, 30], out_type='immutable_proto')
+
+    self.assertEqual(s1, t1)
+    self.assertEqual(s3, t3)
+    self.assertEqual(s4, t4)
+    self.assertEqual(s5, t5)
+    self.assertEqual(s1, y1)
+    self.assertEqual(s3, y3)
+    self.assertEqual(s4, y4)
+    self.assertEqual(s5, y5)
+
+    x1 = self.sp_.encode_as_serialized_proto(text)
+    x2 = self.sp_.sample_encode_as_serialized_proto(text, 10, 0.2)
+    x3 = self.sp_.nbest_encode_as_serialized_proto(text, 10)
+    x4 = self.sp_.decode_pieces_as_serialized_proto(['foo', 'bar'])
+    x5 = self.sp_.decode_ids_as_serialized_proto([20, 30])
+
+    self.assertEqual(x1, t1.SerializeAsString())
+    self.assertEqual(x3, t3.SerializeAsString())
+    self.assertEqual(x4, t4.SerializeAsString())
+    self.assertEqual(x5, t5.SerializeAsString())
+
+    v1 = self.sp_.EncodeAsIds(text)
+    v2 = self.sp_.EncodeAsPieces(text)
+    self.assertEqual([x.id() for x in s1], v1)
+    self.assertEqual([x.piece() for x in s1], v2)
+    self.assertEqual(text, s1.text())
+
+    surfaces1 = [s1.text()[x.begin():x.end()] for x in s1]
+    surfaces2 = [x.surface() for x in s1]
+    self.assertEqual(surfaces1, surfaces2)
+
+    ids = []
+    for i in range(s1.pieces_size()):
+      ids.append(s1.pieces(i).id())
+    self.assertEqual(ids, v1)
+
+    pieces = []
+    for i in range(s1.pieces_size()):
+      pieces.append(s1.pieces(i).piece())
+    self.assertEqual(pieces, v2)
+
+    # Japanese offset
+    s1 = self.jasp_.EncodeAsImmutableProto('吾輩は猫である。Hello world. ABC 123')
+    surfaces1 = [s1.text()[x.begin():x.end()] for x in s1]
+    surfaces2 = [x.surface() for x in s1]
+    self.assertEqual(surfaces1, surfaces2)
+
+    ids = [x.id() for x in s1]
+    s2 = self.jasp_.DecodeIdsAsImmutableProto(ids)
+    self.assertEqual(s2, s1)
+
+    pieces = [x.piece() for x in s1]
+    s2 = self.jasp_.DecodePiecesAsImmutableProto(pieces)
+    self.assertEqual(s2, s1)
 
   def test_new_api(self):
     sp = spm.SentencePieceProcessor(
@@ -386,49 +479,102 @@ class TestSentencepieceProcessor(unittest.TestCase):
     self.assertEqual(pieces, sp.encode(text, add_bos=False, add_eos=True))
 
   def test_sampling(self):
-    sp = spm.SentencePieceProcessor(
-        model_file=os.path.join('test', 'test_model.model'),
-        out_type=str,
-        enable_sampling=True)
-    ids = defaultdict(int)
-    for n in range(100):
-      ++ids[' '.join(sp.encode('hello world'))]
-    self.assertGreater(len(ids), 1)
-
-    ids2 = defaultdict(int)
-    for n in range(100):
-      ++ids2[' '.join(sp.encode('hello world', enable_sampling=False))]
-    self.assertEqual(len(ids2), 1)
+    sp = self.sp_
+    # Sampling on: 100 draws must not all agree. Sampling off: all must agree.
+    for out_type in [str, int, 'serialized_proto', 'immutable_proto']:
+      ids = defaultdict(int)
+      for n in range(100):
+        out = sp.encode('hello world', out_type=out_type, enable_sampling=True)
+        if type(out) is list:
+          out = tuple(out)  # lists are unhashable; use a tuple as the key
+        ids[out] += 1  # `++ids[out]` is a double unary plus, not an increment
+      self.assertGreater(len(ids), 1)
+
+      ids2 = defaultdict(int)
+      for n in range(100):
+        out = sp.encode('hello world', out_type=out_type, enable_sampling=False)
+        if type(out) is list:
+          out = tuple(out)  # lists are unhashable; use a tuple as the key
+        ids2[out] += 1
+      self.assertEqual(len(ids2), 1)
+      # Batch input must return one result per sentence, sampled or not.
+      out = sp.encode(['hello world', 'this is a test'],
+                      out_type=out_type,
+                      enable_sampling=True)
+      self.assertEqual(len(out), 2)
+      out = sp.encode(['hello world', 'this is a test'],
+                      out_type=out_type,
+                      enable_sampling=False)
+      self.assertEqual(len(out), 2)
 
   def test_nbest(self):
-    sp = spm.SentencePieceProcessor(
-        model_file=os.path.join('test', 'test_model.model'))
+    sp = self.sp_
     text = 'hello world'
-    results = sp.nbest_encode(text, nbest_size=10, out_type=str)
-    self.assertEqual(results, sp.NBestEncode(text, nbest_size=10, out_type=str))
-    for n in results:
-      self.assertEqual(sp.decode(n), text)
-    decoded = sp.decode(results)
-    for n in decoded:
-      self.assertEqual(n, text)
-    results = sp.nbest_encode(text, nbest_size=10, out_type=int)
-    self.assertEqual(results, sp.NBestEncode(text, nbest_size=10, out_type=int))
-    for n in results:
-      self.assertEqual(sp.decode(n), text)
-    decoded = sp.decode(results)
-    for n in decoded:
-      self.assertEqual(n, text)
+    text2 = 'I have a pen.'
+
+    for out_type in [str, int, 'serialized_proto', 'immutable_proto']:
+      results = sp.nbest_encode(text, nbest_size=10, out_type=out_type)
+      self.assertEqual(results,
+                       sp.NBestEncode(text, nbest_size=10, out_type=out_type))
+
+      if out_type in [str, int]:
+        for n in results:
+          self.assertEqual(sp.decode(n), text)
+
+        for n in sp.decode(results):
+          self.assertEqual(n, text)
+
+      # Batch input: results[i] is the n-best list for the i-th text.
+      results = sp.nbest_encode([text, text2], nbest_size=10, out_type=out_type)
+      self.assertEqual(
+          results,
+          sp.NBestEncode([text, text2], nbest_size=10, out_type=out_type))
+      self.assertEqual(len(results), 2)
+
+      if out_type in [str, int]:
+        for n in results[0]:
+          self.assertEqual(sp.decode(n), text)
+
+        for n in results[1]:
+          self.assertEqual(sp.decode(n), text2)
+
+        decoded = sp.decode(results[0])
+        self.assertEqual(len(decoded), 10)
+        for n in decoded:
+          self.assertEqual(n, text)
+        decoded = sp.decode(results[1])
+        self.assertEqual(len(decoded), 10)
+        for n in decoded:
+          self.assertEqual(n, text2)
 
   def test_sample_and_score(self):
-    sp = spm.SentencePieceProcessor(
-        model_file=os.path.join('test', 'test_model.model'))
+    sp = self.sp_
     text = 'hello world'
-    results = sp.sample_encode_and_score(text, wor=True, out_type=str)
-    for n in results:
-      self.assertEqual(sp.decode(n[0]), text)
-    results = sp.sample_encode_and_score(text, wor=True, out_type=int)
-    for n in results:
-      self.assertEqual(sp.decode(n[0]), text)
+    text2 = 'I have a pen.'
+    for out_type in [str, int, 'serialized_proto', 'immutable_proto']:
+      results = sp.sample_encode_and_score(
+          text, wor=True, num_samples=10, out_type=out_type)
+      results = sp.SampleEncodeAndScore(
+          text, wor=False, num_samples=10, out_type=out_type)
+      # NOTE(review): only the last `results` is checked; the first is dropped.
+      if out_type in [str, int]:
+        for n in results:
+          self.assertEqual(sp.decode(n[0]), text)
+
+      results = sp.sample_encode_and_score([text, text2],
+                                           wor=True,
+                                           num_samples=10,
+                                           out_type=out_type)
+      results = sp.SampleEncodeAndScore([text, text2],
+                                        wor=True,
+                                        num_samples=10,
+                                        out_type=out_type)
+      # Batch input: results[i] holds the samples drawn for the i-th text.
+      if out_type in [str, int]:
+        for n in results[0]:
+          self.assertEqual(sp.decode(n[0]), text)
+        for n in results[1]:
+          self.assertEqual(sp.decode(n[0]), text2)
 
   def test_valid_range(self):
     size = self.sp_.piece_size()
@@ -452,65 +598,28 @@ class TestSentencepieceProcessor(unittest.TestCase):
     with open(os.path.join(data_dir, 'botchan.txt'), 'r') as file:
       texts = file.readlines()
 
-    r1 = sp.encode(texts, out_type=str, num_threads=None)
-    r2 = sp.encode(texts, out_type=str, num_threads=1)
-    r3 = sp.encode(texts, out_type=str, num_threads=-1)
-    r4 = sp.encode(texts, out_type=str, num_threads=8)
-    r5 = [sp.encode(s, out_type=str) for s in texts]
-    self.assertEqual(r1, r2)
-    self.assertEqual(r1, r3)
-    self.assertEqual(r1, r4)
-    self.assertEqual(r1, r5)
-
-    d1 = sp.decode(r1, num_threads=None)
-    d2 = sp.decode(r2, num_threads=1)
-    d3 = sp.decode(r3, num_threads=-1)
-    d4 = sp.decode(r4, num_threads=8)
-    d5 = [sp.decode(s) for s in r5]
-    self.assertEqual(d1, d2)
-    self.assertEqual(d1, d3)
-    self.assertEqual(d1, d4)
-    self.assertEqual(d1, d5)
-
-    r1 = sp.encode(texts, out_type=int, num_threads=None)
-    r2 = sp.encode(texts, out_type=int, num_threads=1)
-    r3 = sp.encode(texts, out_type=int, num_threads=-1)
-    r4 = sp.encode(texts, out_type=int, num_threads=8)
-    r5 = [sp.encode(s, out_type=int) for s in texts]
-    self.assertEqual(r1, r2)
-    self.assertEqual(r1, r3)
-    self.assertEqual(r1, r4)
-    self.assertEqual(r1, r5)
-
-    d1 = sp.decode(r1, num_threads=None)
-    d2 = sp.decode(r2, num_threads=1)
-    d3 = sp.decode(r3, num_threads=-1)
-    d4 = sp.decode(r4, num_threads=8)
-    d5 = [sp.decode(s) for s in r5]
-    self.assertEqual(d1, d2)
-    self.assertEqual(d1, d3)
-    self.assertEqual(d1, d4)
-    self.assertEqual(d1, d5)
-
-    r1 = sp.encode(texts, out_type='serialized_proto', num_threads=None)
-    r2 = sp.encode(texts, out_type='serialized_proto', num_threads=1)
-    r3 = sp.encode(texts, out_type='serialized_proto', num_threads=-1)
-    r4 = sp.encode(texts, out_type='serialized_proto', num_threads=8)
-    r5 = [sp.encode(s, out_type='serialized_proto') for s in texts]
-    self.assertEqual(r1, r2)
-    self.assertEqual(r1, r3)
-    self.assertEqual(r1, r4)
-    self.assertEqual(r1, r5)
-
-    r1 = sp.encode(texts, out_type='immutable_proto', num_threads=None)
-    r2 = sp.encode(texts, out_type='immutable_proto', num_threads=1)
-    r3 = sp.encode(texts, out_type='immutable_proto', num_threads=-1)
-    r4 = sp.encode(texts, out_type='immutable_proto', num_threads=8)
-    r5 = [sp.encode(s, out_type='immutable_proto') for s in texts]
-    self.assertEqual(r1, r2)
-    self.assertEqual(r1, r3)
-    self.assertEqual(r1, r4)
-    self.assertEqual(r1, r5)
+    for out_type in [str, int, 'serialized_proto', 'immutable_proto']:
+      r1 = sp.encode(texts, out_type=out_type, num_threads=None)
+      r2 = sp.encode(texts, out_type=out_type, num_threads=1)
+      r3 = sp.encode(texts, out_type=out_type, num_threads=-1)
+      r4 = sp.encode(texts, out_type=out_type, num_threads=8)
+      r5 = [sp.encode(s, out_type=out_type) for s in texts]
+      self.assertEqual(r1, r2)
+      self.assertEqual(r1, r3)
+      self.assertEqual(r1, r4)
+      self.assertEqual(r1, r5)
+
+      if out_type in [str, int]:
+        d1 = sp.decode(r1, num_threads=None)
+        d2 = sp.decode(r2, num_threads=1)
+        d3 = sp.decode(r3, num_threads=-1)
+        d4 = sp.decode(r4, num_threads=8)
+        d5 = [sp.decode(s) for s in r5]
+
+        self.assertEqual(d1, d2)
+        self.assertEqual(d1, d3)
+        self.assertEqual(d1, d4)
+        self.assertEqual(d1, d5)
 
     e1 = sp.calculate_entropy(texts, alpha=1.0, num_threads=10)
     e2 = sp.CalculateEntropy(texts, alpha=1.0, num_threads=10)
index 482a45bf1677f75fadbf543a6a37c5d1229308d8..2a5c39932ab8698889a1d66ac9b244faf5182017 100644 (file)
@@ -55,6 +55,34 @@ std::vector<absl::string_view> ToPieceArray(const std::vector<std::string> &v) {
   return out;
 }
 
+// Converts each piece's begin/end from UTF-8 byte offsets to Unicode offsets.
+void ConvertToUnicodeSpansInternal(SentencePieceText *spt) {
+  if (spt == nullptr) return;
+  std::vector<int> utf8_to_unicode(spt->text().size() + 1, 0);
+  absl::string_view str = spt->text();
+  size_t prev = 0;
+  int ulen = 0;
+  while (!str.empty()) {
+    // Keep the step within [1, str.size()] so malformed input stays in bounds.
+    const size_t mblen = std::min<size_t>(
+        std::max<size_t>(1, string_util::OneCharLen(str.data())), str.size());
+    for (size_t i = prev; i < prev + mblen; ++i) utf8_to_unicode[i] = ulen;
+    ++ulen;
+    prev += mblen;
+    str.remove_prefix(mblen);
+  }
+  utf8_to_unicode[prev] = ulen;
+
+  auto clip = [&](int s) {
+    return std::min<int>(std::max<int>(0, s), utf8_to_unicode.size() - 1);
+  };
+
+  for (auto &piece : *(spt->mutable_pieces())) {
+    piece.set_begin(utf8_to_unicode[clip(piece.begin())]);
+    piece.set_end(utf8_to_unicode[clip(piece.end())]);
+  }
+}
+
 }  // namespace
 
 ImmutableSentencePieceText::ImmutableSentencePieceText()
@@ -132,6 +160,10 @@ SentencePieceText *ImmutableSentencePieceText::mutable_proto() {
   return rep_.get();
 }
 
+void ImmutableSentencePieceText::ConvertToUnicodeSpans() {
+  ConvertToUnicodeSpansInternal(mutable_proto());
+}
+
 util::bytes ImmutableSentencePieceText::SerializeAsString() const {
   return spt_->SerializeAsString();
 }
@@ -164,6 +196,13 @@ NBestSentencePieceText *ImmutableNBestSentencePieceText::mutable_proto() {
   return rep_.get();
 }
 
+void ImmutableNBestSentencePieceText::ConvertToUnicodeSpans() {
+  if (!mutable_proto()) return;
+  for (auto &spt : *(mutable_proto()->mutable_nbests())) {
+    ConvertToUnicodeSpansInternal(&spt);
+  }
+}
+
 util::bytes ImmutableNBestSentencePieceText::SerializeAsString() const {
   return rep_ ? rep_->SerializeAsString() : "";
 }
@@ -1048,34 +1087,6 @@ std::string SentencePieceProcessor::serialized_model_proto() const {
 // std::random_device.
 void SetRandomGeneratorSeed(unsigned int seed);
 
-void ConvertToUnicodeSpans(SentencePieceText *spt) {
-  if (spt == nullptr) return;
-
-  std::vector<int> utf8_to_unicode(spt->text().size() + 1, 0);
-  absl::string_view str = spt->text();
-  size_t prev = 0;
-  int ulen = 0;
-  while (!str.empty()) {
-    const size_t mblen = string_util::OneCharLen(str.data());
-    for (int i = prev; i < prev + mblen; ++i) {
-      utf8_to_unicode[i] = ulen;
-    }
-    ++ulen;
-    prev += mblen;
-    str.remove_prefix(mblen);
-  }
-  utf8_to_unicode[prev] = ulen;
-
-  auto clip = [&](int s) {
-    return std::min<int>(std::max<int>(0, s), utf8_to_unicode.size() - 1);
-  };
-
-  for (auto &piece : *(spt->mutable_pieces())) {
-    piece.set_begin(utf8_to_unicode[clip(piece.begin())]);
-    piece.set_end(utf8_to_unicode[clip(piece.end())]);
-  }
-}
-
 namespace io {
 util::Status LoadModelProto(absl::string_view filename,
                             ModelProto *model_proto) {
index b7fae6a3defe6462b5de3a0894ad268ed8806379..d107a2a4d8b515c69da0a151a36322c3b419b291 100644 (file)
@@ -25,8 +25,8 @@
 #ifndef SWIG
 namespace absl {
 using std::string_view;
-}
-#endif  // SWIG
+}  // namespace absl
+#endif  // SWIG
 
 namespace sentencepiece {
 namespace util {
@@ -196,6 +196,9 @@ class ImmutableSentencePieceText {
   // it returns the raw pointer managed by the shared_ptr.
   SentencePieceText *mutable_proto();
 
+  // Converts the utf8 byte spans into Unicode char span.
+  void ConvertToUnicodeSpans();
+
   friend class ImmutableNBestSentencePieceText;
 
  private:
@@ -225,6 +228,8 @@ class ImmutableNBestSentencePieceText {
   // it returns the raw pointer managed by the shared_ptr.
   NBestSentencePieceText *mutable_proto();
 
+  void ConvertToUnicodeSpans();
+
  private:
   std::shared_ptr<NBestSentencePieceText> rep_;
 };
@@ -415,14 +420,16 @@ class SentencePieceProcessor {
   virtual util::Status Decode(const std::vector<int> &ids,
                               SentencePieceText *spt) const;
 
-#ifdef SWIG
+#ifdef SWIGPYTHON
+#define CONVERT_TO_UNICODE_SPAN output.ConvertToUnicodeSpans();
 #define SPP_SWIG_CHECK_AND_THROW \
   if (!status.ok()) throw status;
 #else
+#define CONVERT_TO_UNICODE_SPAN
 #define SPP_SWIG_CHECK_AND_THROW \
   if (!status.ok()) {            \
   }
-#endif  // SWIG
+#endif  // SWIGPYTHON
 
 #define DEFINE_SPP_DIRECT_FUNC_IMPL(FuncName, OutType, ...) \
   OutType output;                                           \
@@ -439,6 +446,7 @@ class SentencePieceProcessor {
 #define DEFINE_SPP_IMMUTABLE_PROTO_IMPL(FuncName, OutType, ...)      \
   OutType output;                                                    \
   const auto status = FuncName(__VA_ARGS__, output.mutable_proto()); \
+  CONVERT_TO_UNICODE_SPAN;                                           \
   SPP_SWIG_CHECK_AND_THROW;                                          \
   return output;
 
@@ -707,9 +715,6 @@ class SentencePieceProcessor {
 // std::random_device.
 void SetRandomGeneratorSeed(unsigned int seed);
 
-// Converts the utf8 byte spans into Unicode char span.
-void ConvertToUnicodeSpans(SentencePieceText *spt);
-
 #ifndef SWIG
 // IO related functions to absorb model formats.
 namespace io {
index ff55aeb1cf3e17589b3a142e8fd32c69d5a25837..f05dc5d1832370068c58591ae854628792914929 100644 (file)
@@ -1657,11 +1657,12 @@ TEST(SentencePieceProcessorTest, ImmutableNBestSentencePieceTextTest) {
 
 TEST(SentencePieceProcessorTest, ConvertToUnicodeSpansTest) {
   auto make_spt = [&](const std::vector<std::string> &tokens) {
-    SentencePieceText spt;
+    ImmutableSentencePieceText ispt;
+    auto *spt = ispt.mutable_proto();
     int prev = 0;
     std::string text;
     for (const auto &tok : tokens) {
-      auto *piece = spt.add_pieces();
+      auto *piece = spt->add_pieces();
       piece->set_surface(tok);
       piece->set_piece(tok);
       piece->set_begin(prev);
@@ -1669,9 +1670,9 @@ TEST(SentencePieceProcessorTest, ConvertToUnicodeSpansTest) {
       prev += tok.size();
       text += tok;
     }
-    spt.set_text(text);
-    ConvertToUnicodeSpans(&spt);
-    return spt;
+    spt->set_text(text);
+    ispt.ConvertToUnicodeSpans();
+    return ispt;
   };
 
   {