From: Taku Kudo Date: Tue, 2 Aug 2022 17:24:53 +0000 (+0900) Subject: Adds more unittests X-Git-Tag: archive/raspbian/0.1.97-3+rpi1^2~15 X-Git-Url: https://dgit.raspbian.org/?a=commitdiff_plain;h=9d015be1d31fe6d2e54433a4253458635f4daa75;p=sentencepiece.git Adds more unittests Signed-off-by: Kentaro Hayashi Gbp-Pq: Name 0013-Adds-more-unittests.patch --- diff --git a/python/src/sentencepiece/__init__.py b/python/src/sentencepiece/__init__.py index 69a9825..07acb94 100644 --- a/python/src/sentencepiece/__init__.py +++ b/python/src/sentencepiece/__init__.py @@ -98,6 +98,9 @@ class ImmutableSentencePieceText(object): def pieces_size(self): return _sentencepiece.ImmutableSentencePieceText_pieces_size(self) + def pieces(self, index): + return _sentencepiece.ImmutableSentencePieceText_pieces(self, index) + def text(self): return _sentencepiece.ImmutableSentencePieceText_text(self) @@ -107,18 +110,24 @@ class ImmutableSentencePieceText(object): def SerializeAsString(self): return _sentencepiece.ImmutableSentencePieceText_SerializeAsString(self) - def pieces(self, index): - return _sentencepiece.ImmutableSentencePieceText_pieces(self, index) + def _pieces(self, index): + return _sentencepiece.ImmutableSentencePieceText__pieces(self, index) + + def pieces(self, i): + return self._pieces(i) def __len__(self): return self.pieces_size() def __getitem__(self, i): - return self.pieces(i) + return self._pieces(i) def __eq__(self, other): return self.SerializeAsString() == other.SerializeAsString() + def __hash__(self): + return hash(self.SerializeAsString()) + # Register ImmutableSentencePieceText in _sentencepiece: _sentencepiece.ImmutableSentencePieceText_swigregister(ImmutableSentencePieceText) @@ -134,21 +143,30 @@ class ImmutableNBestSentencePieceText(object): def nbests_size(self): return _sentencepiece.ImmutableNBestSentencePieceText_nbests_size(self) + def nbests(self, index): + return _sentencepiece.ImmutableNBestSentencePieceText_nbests(self, index) + def SerializeAsString(self): return _sentencepiece.ImmutableNBestSentencePieceText_SerializeAsString(self) - def nbests(self, index): - return _sentencepiece.ImmutableNBestSentencePieceText_nbests(self, index) + def _nbests(self, index): + return _sentencepiece.ImmutableNBestSentencePieceText__nbests(self, index) + + def __nbests__(self, i): + return self._nbests(i) def __len__(self): return self.nbests_size() def __getitem__(self, i): - return self.nbests(i) + return self._nbests(i) def __eq__(self, other): return self.SerializeAsString() == other.SerializeAsString() + def __hash__(self): + return hash(self.SerializeAsString()) + # Register ImmutableNBestSentencePieceText in _sentencepiece: _sentencepiece.ImmutableNBestSentencePieceText_swigregister(ImmutableNBestSentencePieceText) @@ -272,6 +290,9 @@ class SentencePieceProcessor(object): def _DecodeIdsAsSerializedProtoBatch(self, ins, num_threads): return _sentencepiece.SentencePieceProcessor__DecodeIdsAsSerializedProtoBatch(self, ins, num_threads) + def _DecodeIdsAsImmutableProtoBatch(self, ins, num_threads): + return _sentencepiece.SentencePieceProcessor__DecodeIdsAsImmutableProtoBatch(self, ins, num_threads) + def _DecodePiecesBatch(self, ins, num_threads): return _sentencepiece.SentencePieceProcessor__DecodePiecesBatch(self, ins, num_threads) @@ -539,6 +560,8 @@ class SentencePieceProcessor(object): return self._NBestEncodeAsImmutableProto(text, nbest_size, add_bos, add_eos, reverse, emit_unk_piece) + raise RuntimeError('unknown out_type') + if type(input) is list: return [_encode(n) for n in input] @@ -621,10 +644,21 @@ class SentencePieceProcessor(object): if out_type is int: return self._SampleEncodeAndScoreAsIds(text, num_samples, alpha, wor, include_best, add_bos, add_eos, reverse, emit_unk_piece) - else: + if out_type is str: return self._SampleEncodeAndScoreAsPieces(text, num_samples, alpha, wor, include_best, add_bos, add_eos, reverse, emit_unk_piece) + if out_type == 'serialized_proto' or out_type == 'proto': + return self._SampleEncodeAndScoreAsSerializedProto(text, num_samples, alpha, wor, include_best, + add_bos, add_eos, reverse, emit_unk_piece) + + if out_type == 'immutable_proto': + return self._SampleEncodeAndScoreAsImmutableProto(text, num_samples, alpha, wor, include_best, + add_bos, add_eos, reverse, emit_unk_piece) + + raise RuntimeError('unknown output type') + + if type(input) is list: return [_encode(n) for n in input] diff --git a/python/src/sentencepiece/sentencepiece.i b/python/src/sentencepiece/sentencepiece.i index 1e2e1e0..f3a4f30 100644 --- a/python/src/sentencepiece/sentencepiece.i +++ b/python/src/sentencepiece/sentencepiece.i @@ -2,6 +2,7 @@ %include exception.i %{ + #include #include #include @@ -286,8 +287,10 @@ inline void InitNumThreads(const std::vector &ins, int *num_threads) { %ignore sentencepiece::SentencePieceProcessor::status; %ignore sentencepiece::ImmutableSentencePieceText::mutable_proto; %ignore sentencepiece::ImmutableSentencePieceText::pieces() const; +%ignore sentencepiece::ImmutableSentencePieceText::ConvertToUnicodeSpans; %ignore sentencepiece::ImmutableNBestSentencePieceText::mutable_proto; %ignore sentencepiece::ImmutableNBestSentencePieceText::nbests() const; +%ignore sentencepiece::ImmutableNBestSentencePieceText::ConvertToUnicodeSpans; %ignore sentencepiece::SentencePieceProcessor::Encode; %ignore sentencepiece::SentencePieceProcessor::SampleEncode; @@ -481,6 +484,13 @@ inline void InitNumThreads(const std::vector &ins, int *num_threads) { sentencepiece::util::bytes); } + std::vector + _DecodeIdsAsImmutableProtoBatch( + const std::vector> &ins, int num_threads) const { + DEFINE_DECODE_BATCH_FUNC_IMPL(DecodeIdsAsImmutableProto, int, + sentencepiece::ImmutableSentencePieceText); + } + std::vector _DecodePiecesBatch( const std::vector> &ins, int num_threads) const { DEFINE_DECODE_BATCH_FUNC_IMPL(DecodePieces, std::string, std::string); @@ -852,6 +862,8 @@ inline void InitNumThreads(const std::vector &ins, int *num_threads) { return self._NBestEncodeAsImmutableProto(text, nbest_size, add_bos, add_eos, reverse, emit_unk_piece) + raise RuntimeError('unknown out_type') + if type(input) is list: return [_encode(n) for n in input] @@ -934,10 +946,21 @@ inline void InitNumThreads(const std::vector &ins, int *num_threads) { if out_type is int: return self._SampleEncodeAndScoreAsIds(text, num_samples, alpha, wor, include_best, add_bos, add_eos, reverse, emit_unk_piece) - else: + if out_type is str: return self._SampleEncodeAndScoreAsPieces(text, num_samples, alpha, wor, include_best, add_bos, add_eos, reverse, emit_unk_piece) + if out_type == 'serialized_proto' or out_type == 'proto': + return self._SampleEncodeAndScoreAsSerializedProto(text, num_samples, alpha, wor, include_best, + add_bos, add_eos, reverse, emit_unk_piece) + + if out_type == 'immutable_proto': + return self._SampleEncodeAndScoreAsImmutableProto(text, num_samples, alpha, wor, include_best, + add_bos, add_eos, reverse, emit_unk_piece) + + raise RuntimeError('unknown output type') + + if type(input) is list: return [_encode(n) for n in input] @@ -1187,7 +1210,7 @@ inline void InitNumThreads(const std::vector &ins, int *num_threads) { } %extend sentencepiece::ImmutableSentencePieceText { - ImmutableSentencePieceText_ImmutableSentencePiece pieces(int index) const { + ImmutableSentencePieceText_ImmutableSentencePiece _pieces(int index) const { if (index < 0 || index >= static_cast($self->pieces_size())) { throw sentencepiece::util::Status( sentencepiece::util::StatusCode::kOutOfRange, @@ -1197,19 +1220,25 @@ inline void InitNumThreads(const std::vector &ins, int *num_threads) { } %pythoncode { + def pieces(self, i): + return self._pieces(i) + def __len__(self): return self.pieces_size() def __getitem__(self, i): - return self.pieces(i) + return self._pieces(i) def __eq__(self, other): return self.SerializeAsString() == other.SerializeAsString() + + def __hash__(self): + return hash(self.SerializeAsString()) } } %extend sentencepiece::ImmutableNBestSentencePieceText { - ImmutableSentencePieceText nbests(int index) const { + ImmutableSentencePieceText _nbests(int index) const { if (index < 0 || index >= static_cast($self->nbests_size())) { throw sentencepiece::util::Status( sentencepiece::util::StatusCode::kOutOfRange, @@ -1219,14 +1248,20 @@ inline void InitNumThreads(const std::vector &ins, int *num_threads) { } %pythoncode { + def __nbests__(self, i): + return self._nbests(i) + def __len__(self): return self.nbests_size() def __getitem__(self, i): - return self.nbests(i) + return self._nbests(i) def __eq__(self, other): return self.SerializeAsString() == other.SerializeAsString() + + def __hash__(self): + return hash(self.SerializeAsString()) } } diff --git a/python/src/sentencepiece/sentencepiece_wrap.cxx b/python/src/sentencepiece/sentencepiece_wrap.cxx index 9776b0f..22e0708 100644 --- a/python/src/sentencepiece/sentencepiece_wrap.cxx +++ b/python/src/sentencepiece/sentencepiece_wrap.cxx @@ -2811,6 +2811,7 @@ namespace swig { } + #include #include #include @@ -3132,16 +3133,6 @@ SWIG_From_size_t (size_t value) } - #define SWIG_From_double PyFloat_FromDouble - - -SWIGINTERNINLINE PyObject * -SWIG_From_float (float value) -{ - return SWIG_From_double (value); -} - - SWIGINTERN int SWIG_AsVal_double (PyObject *obj, double *val) { @@ -3282,7 +3273,17 @@ SWIG_AsVal_int (PyObject * obj, int *val) return res; } -SWIGINTERN sentencepiece::ImmutableSentencePieceText_ImmutableSentencePiece sentencepiece_ImmutableSentencePieceText_pieces(sentencepiece::ImmutableSentencePieceText const *self,int index){ + + #define SWIG_From_double PyFloat_FromDouble + + +SWIGINTERNINLINE PyObject * +SWIG_From_float (float value) +{ + return SWIG_From_double (value); +} + +SWIGINTERN sentencepiece::ImmutableSentencePieceText_ImmutableSentencePiece sentencepiece_ImmutableSentencePieceText__pieces(sentencepiece::ImmutableSentencePieceText const *self,int index){ if (index < 0 || index >= static_cast(self->pieces_size())) { throw sentencepiece::util::Status( sentencepiece::util::StatusCode::kOutOfRange, @@ -3290,7 +3291,7 @@ SWIGINTERN sentencepiece::ImmutableSentencePieceText_ImmutableSentencePiece sent } return self->pieces(index); } -SWIGINTERN sentencepiece::ImmutableSentencePieceText sentencepiece_ImmutableNBestSentencePieceText_nbests(sentencepiece::ImmutableNBestSentencePieceText const *self,int index){ +SWIGINTERN sentencepiece::ImmutableSentencePieceText sentencepiece_ImmutableNBestSentencePieceText__nbests(sentencepiece::ImmutableNBestSentencePieceText const *self,int index){ if (index < 0 || index >= static_cast(self->nbests_size())) { throw sentencepiece::util::Status( sentencepiece::util::StatusCode::kOutOfRange, @@ -3590,6 +3591,10 @@ SWIGINTERN BytesArray sentencepiece_SentencePieceProcessor__DecodeIdsAsSerialize DEFINE_DECODE_BATCH_FUNC_IMPL(DecodeIdsAsSerializedProto, int, sentencepiece::util::bytes); } +SWIGINTERN std::vector< sentencepiece::ImmutableSentencePieceText > sentencepiece_SentencePieceProcessor__DecodeIdsAsImmutableProtoBatch(sentencepiece::SentencePieceProcessor const *self,std::vector< std::vector< int > > const &ins,int num_threads){ + DEFINE_DECODE_BATCH_FUNC_IMPL(DecodeIdsAsImmutableProto, int, + sentencepiece::ImmutableSentencePieceText); + } SWIGINTERN std::vector< std::string > sentencepiece_SentencePieceProcessor__DecodePiecesBatch(sentencepiece::SentencePieceProcessor const *self,std::vector< std::vector< absl::string_view > > const &ins,int num_threads){ DEFINE_DECODE_BATCH_FUNC_IMPL(DecodePieces, std::string, std::string); } @@ -4070,6 +4075,44 @@ fail: } +SWIGINTERN PyObject *_wrap_ImmutableSentencePieceText_pieces(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { + PyObject *resultobj = 0; + sentencepiece::ImmutableSentencePieceText *arg1 = (sentencepiece::ImmutableSentencePieceText *) 0 ; + int arg2 ; + void *argp1 = 0 ; + int res1 = 0 ; + int val2 ; + int ecode2 = 0 ; + PyObject *swig_obj[2] ; + sentencepiece::ImmutableSentencePieceText_ImmutableSentencePiece result; + + if (!SWIG_Python_UnpackTuple(args, "ImmutableSentencePieceText_pieces", 2, 2, swig_obj)) SWIG_fail; + res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__ImmutableSentencePieceText, 0 | 0 ); + if (!SWIG_IsOK(res1)) { + SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "ImmutableSentencePieceText_pieces" "', argument " "1"" of type '" "sentencepiece::ImmutableSentencePieceText const *""'"); + } + arg1 = reinterpret_cast< sentencepiece::ImmutableSentencePieceText * >(argp1); + ecode2 = SWIG_AsVal_int(swig_obj[1], &val2); + if (!SWIG_IsOK(ecode2)) { + SWIG_exception_fail(SWIG_ArgError(ecode2), "in method '" "ImmutableSentencePieceText_pieces" "', argument " "2"" of type '" "int""'"); + } + arg2 = static_cast< int >(val2); + { + try { + result = ((sentencepiece::ImmutableSentencePieceText const *)arg1)->pieces(arg2); + ReleaseResultObject(resultobj); + } + catch (const sentencepiece::util::Status &status) { + SWIG_exception(ToSwigError(status.code()), status.ToString().c_str()); + } + } + resultobj = SWIG_NewPointerObj((new sentencepiece::ImmutableSentencePieceText_ImmutableSentencePiece(static_cast< const sentencepiece::ImmutableSentencePieceText_ImmutableSentencePiece& >(result))), SWIGTYPE_p_sentencepiece__ImmutableSentencePieceText_ImmutableSentencePiece, SWIG_POINTER_OWN | 0 ); + return resultobj; +fail: + return NULL; +} + + SWIGINTERN PyObject *_wrap_ImmutableSentencePieceText_text(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { PyObject *resultobj = 0; sentencepiece::ImmutableSentencePieceText *arg1 = (sentencepiece::ImmutableSentencePieceText *) 0 ; @@ -4168,7 +4211,7 @@ fail: } -SWIGINTERN PyObject *_wrap_ImmutableSentencePieceText_pieces(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { +SWIGINTERN PyObject *_wrap_ImmutableSentencePieceText__pieces(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { PyObject *resultobj = 0; sentencepiece::ImmutableSentencePieceText *arg1 = (sentencepiece::ImmutableSentencePieceText *) 0 ; int arg2 ; @@ -4179,20 +4222,20 @@ SWIGINTERN PyObject *_wrap_ImmutableSentencePieceText_pieces(PyObject *SWIGUNUSE PyObject *swig_obj[2] ; sentencepiece::ImmutableSentencePieceText_ImmutableSentencePiece result; - if (!SWIG_Python_UnpackTuple(args, "ImmutableSentencePieceText_pieces", 2, 2, swig_obj)) SWIG_fail; + if (!SWIG_Python_UnpackTuple(args, "ImmutableSentencePieceText__pieces", 2, 2, swig_obj)) SWIG_fail; res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__ImmutableSentencePieceText, 0 | 0 ); if (!SWIG_IsOK(res1)) { - SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "ImmutableSentencePieceText_pieces" "', argument " "1"" of type '" "sentencepiece::ImmutableSentencePieceText const *""'"); + SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "ImmutableSentencePieceText__pieces" "', argument " "1"" of type '" "sentencepiece::ImmutableSentencePieceText const *""'"); } arg1 = reinterpret_cast< sentencepiece::ImmutableSentencePieceText * >(argp1); ecode2 = SWIG_AsVal_int(swig_obj[1], &val2); if (!SWIG_IsOK(ecode2)) { - SWIG_exception_fail(SWIG_ArgError(ecode2), "in method '" "ImmutableSentencePieceText_pieces" "', argument " "2"" of type '" "int""'"); + SWIG_exception_fail(SWIG_ArgError(ecode2), "in method '" "ImmutableSentencePieceText__pieces" "', argument " "2"" of type '" "int""'"); } arg2 = static_cast< int >(val2); { try { - result = sentencepiece_ImmutableSentencePieceText_pieces((sentencepiece::ImmutableSentencePieceText const *)arg1,arg2); + result = sentencepiece_ImmutableSentencePieceText__pieces((sentencepiece::ImmutableSentencePieceText const *)arg1,arg2); ReleaseResultObject(resultobj); } catch (const sentencepiece::util::Status &status) { @@ -4299,6 +4342,44 @@ fail: } +SWIGINTERN PyObject *_wrap_ImmutableNBestSentencePieceText_nbests(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { + PyObject *resultobj = 0; + sentencepiece::ImmutableNBestSentencePieceText *arg1 = (sentencepiece::ImmutableNBestSentencePieceText *) 0 ; + int arg2 ; + void *argp1 = 0 ; + int res1 = 0 ; + int val2 ; + int ecode2 = 0 ; + PyObject *swig_obj[2] ; + sentencepiece::ImmutableSentencePieceText result; + + if (!SWIG_Python_UnpackTuple(args, "ImmutableNBestSentencePieceText_nbests", 2, 2, swig_obj)) SWIG_fail; + res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__ImmutableNBestSentencePieceText, 0 | 0 ); + if (!SWIG_IsOK(res1)) { + SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "ImmutableNBestSentencePieceText_nbests" "', argument " "1"" of type '" "sentencepiece::ImmutableNBestSentencePieceText const *""'"); + } + arg1 = reinterpret_cast< sentencepiece::ImmutableNBestSentencePieceText * >(argp1); + ecode2 = SWIG_AsVal_int(swig_obj[1], &val2); + if (!SWIG_IsOK(ecode2)) { + SWIG_exception_fail(SWIG_ArgError(ecode2), "in method '" "ImmutableNBestSentencePieceText_nbests" "', argument " "2"" of type '" "int""'"); + } + arg2 = static_cast< int >(val2); + { + try { + result = ((sentencepiece::ImmutableNBestSentencePieceText const *)arg1)->nbests(arg2); + ReleaseResultObject(resultobj); + } + catch (const sentencepiece::util::Status &status) { + SWIG_exception(ToSwigError(status.code()), status.ToString().c_str()); + } + } + resultobj = SWIG_NewPointerObj((new sentencepiece::ImmutableSentencePieceText(static_cast< const sentencepiece::ImmutableSentencePieceText& >(result))), SWIGTYPE_p_sentencepiece__ImmutableSentencePieceText, SWIG_POINTER_OWN | 0 ); + return resultobj; +fail: + return NULL; +} + + SWIGINTERN PyObject *_wrap_ImmutableNBestSentencePieceText_SerializeAsString(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { PyObject *resultobj = 0; sentencepiece::ImmutableNBestSentencePieceText *arg1 = (sentencepiece::ImmutableNBestSentencePieceText *) 0 ; @@ -4332,7 +4413,7 @@ fail: } -SWIGINTERN PyObject *_wrap_ImmutableNBestSentencePieceText_nbests(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { +SWIGINTERN PyObject *_wrap_ImmutableNBestSentencePieceText__nbests(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { PyObject *resultobj = 0; sentencepiece::ImmutableNBestSentencePieceText *arg1 = (sentencepiece::ImmutableNBestSentencePieceText *) 0 ; int arg2 ; @@ -4343,20 +4424,20 @@ SWIGINTERN PyObject *_wrap_ImmutableNBestSentencePieceText_nbests(PyObject *SWIG PyObject *swig_obj[2] ; sentencepiece::ImmutableSentencePieceText result; - if (!SWIG_Python_UnpackTuple(args, "ImmutableNBestSentencePieceText_nbests", 2, 2, swig_obj)) SWIG_fail; + if (!SWIG_Python_UnpackTuple(args, "ImmutableNBestSentencePieceText__nbests", 2, 2, swig_obj)) SWIG_fail; res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__ImmutableNBestSentencePieceText, 0 | 0 ); if (!SWIG_IsOK(res1)) { - SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "ImmutableNBestSentencePieceText_nbests" "', argument " "1"" of type '" "sentencepiece::ImmutableNBestSentencePieceText const *""'"); + SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "ImmutableNBestSentencePieceText__nbests" "', argument " "1"" of type '" "sentencepiece::ImmutableNBestSentencePieceText const *""'"); } arg1 = reinterpret_cast< sentencepiece::ImmutableNBestSentencePieceText * >(argp1); ecode2 = SWIG_AsVal_int(swig_obj[1], &val2); if (!SWIG_IsOK(ecode2)) { - SWIG_exception_fail(SWIG_ArgError(ecode2), "in method '" "ImmutableNBestSentencePieceText_nbests" "', argument " "2"" of type '" "int""'"); + SWIG_exception_fail(SWIG_ArgError(ecode2), "in method '" "ImmutableNBestSentencePieceText__nbests" "', argument " "2"" of type '" "int""'"); } arg2 = static_cast< int >(val2); { try { - result = sentencepiece_ImmutableNBestSentencePieceText_nbests((sentencepiece::ImmutableNBestSentencePieceText const *)arg1,arg2); + result = sentencepiece_ImmutableNBestSentencePieceText__nbests((sentencepiece::ImmutableNBestSentencePieceText const *)arg1,arg2); ReleaseResultObject(resultobj); } catch (const sentencepiece::util::Status &status) { @@ -6822,6 +6903,87 @@ fail: } +SWIGINTERN PyObject *_wrap_SentencePieceProcessor__DecodeIdsAsImmutableProtoBatch(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { + PyObject *resultobj = 0; + sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ; + std::vector< std::vector< int > > *arg2 = 0 ; + int arg3 ; + void *argp1 = 0 ; + int res1 = 0 ; + int val3 ; + int ecode3 = 0 ; + PyObject *swig_obj[3] ; + SwigValueWrapper< std::vector< sentencepiece::ImmutableSentencePieceText > > result; + + if (!SWIG_Python_UnpackTuple(args, "SentencePieceProcessor__DecodeIdsAsImmutableProtoBatch", 3, 3, swig_obj)) SWIG_fail; + res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0 ); + if (!SWIG_IsOK(res1)) { + SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor__DecodeIdsAsImmutableProtoBatch" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor const *""'"); + } + arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1); + { + std::vector> *out = nullptr; + if (PyList_Check(swig_obj[1])) { + const size_t size = PyList_Size(swig_obj[1]); + out = new std::vector>(size); + for (size_t i = 0; i < size; ++i) { + PyObject *o = PyList_GetItem(swig_obj[1], i); + if (PyList_Check(o)) { + const size_t size2 = PyList_Size(o); + (*out)[i].resize(size2); + for (size_t j = 0; j < size2; ++j) { + PyObject *o2 = PyList_GetItem(o, j); + if (PyInt_Check(o2)) { + (*out)[i][j] = static_cast(PyInt_AsLong(o2)); + } else { + PyErr_SetString(PyExc_TypeError, "list must contain strings"); + SWIG_fail; + } + } + } else { + PyErr_SetString(PyExc_TypeError, "not a list"); + SWIG_fail; + } + } + } else { + PyErr_SetString(PyExc_TypeError,"not a list"); + SWIG_fail; + } + arg2 = out; + } + ecode3 = SWIG_AsVal_int(swig_obj[2], &val3); + if (!SWIG_IsOK(ecode3)) { + SWIG_exception_fail(SWIG_ArgError(ecode3), "in method '" "SentencePieceProcessor__DecodeIdsAsImmutableProtoBatch" "', argument " "3"" of type '" "int""'"); + } + arg3 = static_cast< int >(val3); + { + try { + result = sentencepiece_SentencePieceProcessor__DecodeIdsAsImmutableProtoBatch((sentencepiece::SentencePieceProcessor const *)arg1,(std::vector< std::vector< int > > const &)*arg2,arg3); + ReleaseResultObject(resultobj); + } + catch (const sentencepiece::util::Status &status) { + SWIG_exception(ToSwigError(status.code()), status.ToString().c_str()); + } + } + { + resultobj = PyList_New((&result)->size()); + for (size_t i = 0; i < (&result)->size(); ++i) { + PyObject *obj = SWIG_NewPointerObj(new sentencepiece::ImmutableSentencePieceText((&result)->at(i)), SWIGTYPE_p_sentencepiece__ImmutableSentencePieceText, SWIG_POINTER_OWN | 0); + PyList_SET_ITEM(resultobj, i, obj); + } + } + { + delete arg2; + } + return resultobj; +fail: + { + delete arg2; + } + return NULL; +} + + SWIGINTERN PyObject *_wrap_SentencePieceProcessor__DecodePiecesBatch(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { PyObject *resultobj = 0; sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ; @@ -8298,17 +8460,19 @@ static PyMethodDef SwigMethods[] = { { "new_ImmutableSentencePieceText", _wrap_new_ImmutableSentencePieceText, METH_NOARGS, NULL}, { "delete_ImmutableSentencePieceText", _wrap_delete_ImmutableSentencePieceText, METH_O, NULL}, { "ImmutableSentencePieceText_pieces_size", _wrap_ImmutableSentencePieceText_pieces_size, METH_O, NULL}, + { "ImmutableSentencePieceText_pieces", _wrap_ImmutableSentencePieceText_pieces, METH_VARARGS, NULL}, { "ImmutableSentencePieceText_text", _wrap_ImmutableSentencePieceText_text, METH_O, NULL}, { "ImmutableSentencePieceText_score", _wrap_ImmutableSentencePieceText_score, METH_O, NULL}, { "ImmutableSentencePieceText_SerializeAsString", _wrap_ImmutableSentencePieceText_SerializeAsString, METH_O, NULL}, - { "ImmutableSentencePieceText_pieces", _wrap_ImmutableSentencePieceText_pieces, METH_VARARGS, NULL}, + { "ImmutableSentencePieceText__pieces", _wrap_ImmutableSentencePieceText__pieces, METH_VARARGS, NULL}, { "ImmutableSentencePieceText_swigregister", ImmutableSentencePieceText_swigregister, METH_O, NULL}, { "ImmutableSentencePieceText_swiginit", ImmutableSentencePieceText_swiginit, METH_VARARGS, NULL}, { "new_ImmutableNBestSentencePieceText", _wrap_new_ImmutableNBestSentencePieceText, METH_NOARGS, NULL}, { "delete_ImmutableNBestSentencePieceText", _wrap_delete_ImmutableNBestSentencePieceText, METH_O, NULL}, { "ImmutableNBestSentencePieceText_nbests_size", _wrap_ImmutableNBestSentencePieceText_nbests_size, METH_O, NULL}, - { "ImmutableNBestSentencePieceText_SerializeAsString", _wrap_ImmutableNBestSentencePieceText_SerializeAsString, METH_O, NULL}, { "ImmutableNBestSentencePieceText_nbests", _wrap_ImmutableNBestSentencePieceText_nbests, METH_VARARGS, NULL}, + { "ImmutableNBestSentencePieceText_SerializeAsString", _wrap_ImmutableNBestSentencePieceText_SerializeAsString, METH_O, NULL}, + { "ImmutableNBestSentencePieceText__nbests", _wrap_ImmutableNBestSentencePieceText__nbests, METH_VARARGS, NULL}, { "ImmutableNBestSentencePieceText_swigregister", ImmutableNBestSentencePieceText_swigregister, METH_O, NULL}, { "ImmutableNBestSentencePieceText_swiginit", ImmutableNBestSentencePieceText_swiginit, METH_VARARGS, NULL}, { "new_SentencePieceProcessor", _wrap_new_SentencePieceProcessor, METH_NOARGS, NULL}, @@ -8350,6 +8514,7 @@ static PyMethodDef SwigMethods[] = { { "SentencePieceProcessor__DecodePiecesAsImmutableProto", _wrap_SentencePieceProcessor__DecodePiecesAsImmutableProto, METH_VARARGS, NULL}, { "SentencePieceProcessor__DecodeIdsBatch", _wrap_SentencePieceProcessor__DecodeIdsBatch, METH_VARARGS, NULL}, { "SentencePieceProcessor__DecodeIdsAsSerializedProtoBatch", _wrap_SentencePieceProcessor__DecodeIdsAsSerializedProtoBatch, METH_VARARGS, NULL}, + { "SentencePieceProcessor__DecodeIdsAsImmutableProtoBatch", _wrap_SentencePieceProcessor__DecodeIdsAsImmutableProtoBatch, METH_VARARGS, NULL}, { "SentencePieceProcessor__DecodePiecesBatch", _wrap_SentencePieceProcessor__DecodePiecesBatch, METH_VARARGS, NULL}, { "SentencePieceProcessor__DecodePiecesAsSerializedProtoBatch", _wrap_SentencePieceProcessor__DecodePiecesAsSerializedProtoBatch, METH_VARARGS, NULL}, { "SentencePieceProcessor__DecodePiecesAsImmutableProtoBatch", _wrap_SentencePieceProcessor__DecodePiecesAsImmutableProtoBatch, METH_VARARGS, NULL}, diff --git a/python/test/sentencepiece_test.py b/python/test/sentencepiece_test.py index 2f2c84a..5e4af7f 100755 --- a/python/test/sentencepiece_test.py +++ b/python/test/sentencepiece_test.py @@ -266,6 +266,13 @@ class TestSentencepieceProcessor(unittest.TestCase): t4 = self.sp_.decode_pieces_as_serialized_proto(['foo', 'bar']) t5 = self.sp_.decode_ids_as_serialized_proto([20, 30]) + y1 = self.sp_.encode(text, out_type='serialized_proto') + y2 = self.sp_.encode( + text, enable_sampling=True, out_type='serialized_proto') + y3 = self.sp_.nbest_encode(text, out_type='serialized_proto', nbest_size=10) + y4 = self.sp_.decode(['foo', 'bar'], out_type='serialized_proto') + y5 = self.sp_.decode([20, 30], out_type='serialized_proto') + self.assertEqual(type(s1), bytes) self.assertEqual(type(s2), bytes) self.assertEqual(type(t2), bytes) @@ -277,6 +284,92 @@ class TestSentencepieceProcessor(unittest.TestCase): self.assertEqual(s3, t3) self.assertEqual(s4, t4) self.assertEqual(s5, t5) + self.assertEqual(s1, y1) + self.assertEqual(s3, y3) + self.assertEqual(s4, y4) + self.assertEqual(s5, y5) + + ids = self.jasp_.EncodeAsIds(text) + pieces = self.jasp_.EncodeAsPieces(text) + s1 = self.jasp_.EncodeAsSerializedProto(text) + s2 = self.jasp_.DecodeIdsAsSerializedProto(ids) + s3 = self.jasp_.DecodePiecesAsSerializedProto(ids) + self.assertEqual(s2, s1) + self.assertEqual(s3, s1) + + def test_immutable_proto(self): + text = 'I saw a girl with a telescope.' + s1 = self.sp_.EncodeAsImmutableProto(text) + s2 = self.sp_.SampleEncodeAsImmutableProto(text, 10, 0.2) + s3 = self.sp_.NBestEncodeAsImmutableProto(text, 10) + s4 = self.sp_.DecodePiecesAsImmutableProto(['foo', 'bar']) + s5 = self.sp_.DecodeIdsAsImmutableProto([20, 30]) + + t1 = self.sp_.encode_as_immutable_proto(text) + t2 = self.sp_.sample_encode_as_immutable_proto(text, 10, 0.2) + t3 = self.sp_.nbest_encode_as_immutable_proto(text, 10) + t4 = self.sp_.decode_pieces_as_immutable_proto(['foo', 'bar']) + t5 = self.sp_.decode_ids_as_immutable_proto([20, 30]) + + y1 = self.sp_.encode(text, out_type='immutable_proto') + y2 = self.sp_.encode(text, enable_sampling=True, out_type='immutable_proto') + y3 = self.sp_.nbest_encode(text, out_type='immutable_proto', nbest_size=10) + y4 = self.sp_.decode(['foo', 'bar'], out_type='immutable_proto') + y5 = self.sp_.decode([20, 30], out_type='immutable_proto') + + self.assertEqual(s1, t1) + self.assertEqual(s3, t3) + self.assertEqual(s4, t4) + self.assertEqual(s5, t5) + self.assertEqual(s1, y1) + self.assertEqual(s3, y3) + self.assertEqual(s4, y4) + self.assertEqual(s5, y5) + + x1 = self.sp_.encode_as_serialized_proto(text) + x2 = self.sp_.sample_encode_as_serialized_proto(text, 10, 0.2) + x3 = self.sp_.nbest_encode_as_serialized_proto(text, 10) + x4 = self.sp_.decode_pieces_as_serialized_proto(['foo', 'bar']) + x5 = self.sp_.decode_ids_as_serialized_proto([20, 30]) + + self.assertEqual(x1, t1.SerializeAsString()) + self.assertEqual(x3, t3.SerializeAsString()) + self.assertEqual(x4, t4.SerializeAsString()) + self.assertEqual(x5, t5.SerializeAsString()) + + v1 = self.sp_.EncodeAsIds(text) + v2 = self.sp_.EncodeAsPieces(text) + self.assertEqual([x.id() for x in s1], v1) + self.assertEqual([x.piece() for x in s1], v2) + self.assertEqual(text, s1.text()) + + surfaces1 = [s1.text()[x.begin():x.end()] for x in s1] + surfaces2 = [x.surface() for x in s1] + self.assertEqual(surfaces1, surfaces2) + + ids = [] + for i in range(s1.pieces_size()): + ids.append(s1.pieces(i).id()) + self.assertEqual(ids, v1) + + pieces = [] + for i in range(s1.pieces_size()): + pieces.append(s1.pieces(i).piece()) + self.assertEqual(pieces, v2) + + # Japanese offset + s1 = self.jasp_.EncodeAsImmutableProto('吾輩は猫である。Hello world. ABC 123') + surfaces1 = [s1.text()[x.begin():x.end()] for x in s1] + surfaces2 = [x.surface() for x in s1] + self.assertEqual(surfaces1, surfaces2) + + ids = [x.id() for x in s1] + s2 = self.jasp_.DecodeIdsAsImmutableProto(ids) + self.assertEqual(s2, s1) + + pieces = [x.piece() for x in s1] + s2 = self.jasp_.DecodePiecesAsImmutableProto(pieces) + self.assertEqual(s2, s1) def test_new_api(self): sp = spm.SentencePieceProcessor( @@ -386,49 +479,102 @@ class TestSentencepieceProcessor(unittest.TestCase): self.assertEqual(pieces, sp.encode(text, add_bos=False, add_eos=True)) def test_sampling(self): - sp = spm.SentencePieceProcessor( - model_file=os.path.join('test', 'test_model.model'), - out_type=str, - enable_sampling=True) - ids = defaultdict(int) - for n in range(100): - ++ids[' '.join(sp.encode('hello world'))] - self.assertGreater(len(ids), 1) - - ids2 = defaultdict(int) - for n in range(100): - ++ids2[' '.join(sp.encode('hello world', enable_sampling=False))] - self.assertEqual(len(ids2), 1) + sp = self.sp_ + + for out_type in [str, int, 'serialized_proto', 'immutable_proto']: + ids = defaultdict(int) + for n in range(100): + out = sp.encode('hello world', out_type=out_type, enable_sampling=True) + if type(out) is list: + out = tuple(out) + ++ids[out] + self.assertGreater(len(ids), 1) + + ids2 = defaultdict(int) + for n in range(100): + out = sp.encode('hello world', out_type=out_type, enable_sampling=False) + if type(out) is list: + out = tuple(out) + ++ids2[out] + self.assertEqual(len(ids2), 1) + + out = sp.encode(['hello world', 'this is a test'], + out_type=out_type, + enable_sampling=True) + self.assertEqual(len(out), 2) + out = sp.encode(['hello world', 'this is a test'], + out_type=out_type, + enable_sampling=False) + self.assertEqual(len(out), 2) def test_nbest(self): - sp = spm.SentencePieceProcessor( - model_file=os.path.join('test', 'test_model.model')) + sp = self.sp_ text = 'hello world' - results = sp.nbest_encode(text, nbest_size=10, out_type=str) - self.assertEqual(results, sp.NBestEncode(text, nbest_size=10, out_type=str)) - for n in results: - self.assertEqual(sp.decode(n), text) - decoded = sp.decode(results) - for n in decoded: - self.assertEqual(n, text) - results = sp.nbest_encode(text, nbest_size=10, out_type=int) - self.assertEqual(results, sp.NBestEncode(text, nbest_size=10, out_type=int)) - for n in results: - self.assertEqual(sp.decode(n), text) - decoded = sp.decode(results) - for n in decoded: - self.assertEqual(n, text) + text2 = 'I have a pen.' + + for out_type in [str, int, 'serialized_proto', 'immutable_proto']: + results = sp.nbest_encode(text, nbest_size=10, out_type=out_type) + self.assertEqual(results, + sp.NBestEncode(text, nbest_size=10, out_type=out_type)) + + if out_type in [str, int]: + for n in results: + self.assertEqual(sp.decode(n), text) + + for n in sp.decode(results): + self.assertEqual(n, text) + + # batch test + results = sp.nbest_encode([text, text2], nbest_size=10, out_type=out_type) + self.assertEqual( + results, + sp.NBestEncode([text, text2], nbest_size=10, out_type=out_type)) + self.assertEqual(len(results), 2) + + if out_type in [str, int]: + for n in results[0]: + self.assertEqual(sp.decode(n), text) + + for n in results[1]: + self.assertEqual(sp.decode(n), text2) + + decoded = sp.decode(results[0]) + self.assertEqual(len(decoded), 10) + for n in decoded: + self.assertEqual(n, text) + decoded = sp.decode(results[1]) + self.assertEqual(len(decoded), 10) + for n in decoded: + self.assertEqual(n, text2) def test_sample_and_score(self): - sp = spm.SentencePieceProcessor( - model_file=os.path.join('test', 'test_model.model')) + sp = self.sp_ text = 'hello world' - results = sp.sample_encode_and_score(text, wor=True, out_type=str) - for n in results: - self.assertEqual(sp.decode(n[0]), text) - results = sp.sample_encode_and_score(text, wor=True, out_type=int) - for n in results: - self.assertEqual(sp.decode(n[0]), text) + text2 = 'I have a pen.' + for out_type in [str, int, 'serialized_proto', 'immutable_proto']: + results = sp.sample_encode_and_score( + text, wor=True, num_samples=10, out_type=out_type) + results = sp.SampleEncodeAndScore( + text, wor=False, num_samples=10, out_type=out_type) + + if out_type in [str, int]: + for n in results: + self.assertEqual(sp.decode(n[0]), text) + + results = sp.sample_encode_and_score([text, text2], + wor=True, + num_samples=10, + out_type=out_type) + results = sp.SampleEncodeAndScore([text, text2], + wor=True, + num_samples=10, + out_type=out_type) + + if out_type in [str, int]: + for n in results[0]: + self.assertEqual(sp.decode(n[0]), text) + for n in results[1]: + self.assertEqual(sp.decode(n[0]), text2) def test_valid_range(self): size = self.sp_.piece_size() @@ -452,65 +598,28 @@ class TestSentencepieceProcessor(unittest.TestCase): with open(os.path.join(data_dir, 'botchan.txt'), 'r') as file: texts = file.readlines() - r1 = sp.encode(texts, out_type=str, num_threads=None) - r2 = sp.encode(texts, out_type=str, num_threads=1) - r3 = sp.encode(texts, out_type=str, num_threads=-1) - r4 = sp.encode(texts, out_type=str, num_threads=8) - r5 = [sp.encode(s, out_type=str) for s in texts] - self.assertEqual(r1, r2) - self.assertEqual(r1, r3) - self.assertEqual(r1, r4) - self.assertEqual(r1, r5) - - d1 = sp.decode(r1, num_threads=None) - d2 = sp.decode(r2, num_threads=1) - d3 = sp.decode(r3, num_threads=-1) - d4 = sp.decode(r4, num_threads=8) - d5 = [sp.decode(s) for s in r5] - self.assertEqual(d1, d2) - self.assertEqual(d1, d3) - self.assertEqual(d1, d4) - self.assertEqual(d1, d5) - - r1 = sp.encode(texts, out_type=int, num_threads=None) - r2 = sp.encode(texts, out_type=int, num_threads=1) - r3 = sp.encode(texts, out_type=int, num_threads=-1) - r4 = sp.encode(texts, out_type=int, num_threads=8) - r5 = [sp.encode(s, out_type=int) for s in texts] - self.assertEqual(r1, r2) - self.assertEqual(r1, r3) - self.assertEqual(r1, r4) - self.assertEqual(r1, r5) - - d1 = sp.decode(r1, num_threads=None) - d2 = sp.decode(r2, num_threads=1) - d3 = sp.decode(r3, num_threads=-1) - d4 = sp.decode(r4, num_threads=8) - d5 = [sp.decode(s) for s in r5] - self.assertEqual(d1, d2) - self.assertEqual(d1, d3) - self.assertEqual(d1, d4) - self.assertEqual(d1, d5) - - r1 = sp.encode(texts, out_type='serialized_proto', num_threads=None) - r2 = sp.encode(texts, out_type='serialized_proto', num_threads=1) - r3 = sp.encode(texts, out_type='serialized_proto', num_threads=-1) - r4 = sp.encode(texts, out_type='serialized_proto', num_threads=8) - r5 = [sp.encode(s, out_type='serialized_proto') for s in texts] - self.assertEqual(r1, r2) - self.assertEqual(r1, r3) - self.assertEqual(r1, r4) - self.assertEqual(r1, r5) - - r1 = sp.encode(texts, out_type='immutable_proto', num_threads=None) - r2 = sp.encode(texts, out_type='immutable_proto', num_threads=1) - r3 = sp.encode(texts, out_type='immutable_proto', num_threads=-1) - r4 = sp.encode(texts, out_type='immutable_proto', num_threads=8) - r5 = [sp.encode(s, out_type='immutable_proto') for s in texts] - self.assertEqual(r1, r2) - self.assertEqual(r1, r3) - self.assertEqual(r1, r4) - self.assertEqual(r1, r5) + for out_type in [str, int, 'serialized_proto', 'immutable_proto']: + r1 = sp.encode(texts, out_type=out_type, num_threads=None) + r2 = sp.encode(texts, out_type=out_type, num_threads=1) + r3 = sp.encode(texts, out_type=out_type, num_threads=-1) + r4 = sp.encode(texts, out_type=out_type, num_threads=8) + r5 = [sp.encode(s, out_type=out_type) for s in texts] + self.assertEqual(r1, r2) + self.assertEqual(r1, r3) + self.assertEqual(r1, r4) + self.assertEqual(r1, r5) + + if out_type in [str, int]: + d1 = sp.decode(r1, num_threads=None) + d2 = sp.decode(r2, num_threads=1) + d3 = sp.decode(r3, num_threads=-1) + d4 = sp.decode(r4, num_threads=8) + d5 = [sp.decode(s) for s in r5] + + self.assertEqual(d1, d2) + self.assertEqual(d1, d3) + self.assertEqual(d1, d4) + self.assertEqual(d1, d5) e1 = sp.calculate_entropy(texts, alpha=1.0, num_threads=10) e2 = sp.CalculateEntropy(texts, alpha=1.0, num_threads=10) diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc index 482a45b..2a5c399 100644 --- a/src/sentencepiece_processor.cc +++ b/src/sentencepiece_processor.cc @@ -55,6 +55,34 @@ std::vector ToPieceArray(const std::vector &v) { return out; } +void ConvertToUnicodeSpansInternal(SentencePieceText *spt) { + if (spt == nullptr) return; + + std::vector utf8_to_unicode(spt->text().size() + 1, 0); + absl::string_view str = spt->text(); + size_t prev = 0; + int ulen = 0; + while (!str.empty()) { + const size_t mblen = string_util::OneCharLen(str.data()); + for (int i = prev; i < prev + mblen; ++i) { + utf8_to_unicode[i] = ulen; + } + ++ulen; + prev += mblen; + str.remove_prefix(mblen); + } + utf8_to_unicode[prev] = ulen; + + auto clip = [&](int s) { + return std::min(std::max(0, s), utf8_to_unicode.size() - 1); + }; + + for (auto &piece : *(spt->mutable_pieces())) { + piece.set_begin(utf8_to_unicode[clip(piece.begin())]); + piece.set_end(utf8_to_unicode[clip(piece.end())]); + } +} + } // namespace ImmutableSentencePieceText::ImmutableSentencePieceText() @@ -132,6 +160,10 @@ SentencePieceText *ImmutableSentencePieceText::mutable_proto() { return rep_.get(); } +void ImmutableSentencePieceText::ConvertToUnicodeSpans() { + ConvertToUnicodeSpansInternal(mutable_proto()); +} + util::bytes ImmutableSentencePieceText::SerializeAsString() const { return spt_->SerializeAsString(); } @@ -164,6 +196,13 @@ NBestSentencePieceText *ImmutableNBestSentencePieceText::mutable_proto() { return rep_.get(); } +void ImmutableNBestSentencePieceText::ConvertToUnicodeSpans() { + if (!mutable_proto()) return; + for (auto &spt : *(mutable_proto()->mutable_nbests())) { + ConvertToUnicodeSpansInternal(&spt); + } +} + util::bytes ImmutableNBestSentencePieceText::SerializeAsString() const { return rep_ ? rep_->SerializeAsString() : ""; } @@ -1048,34 +1087,6 @@ std::string SentencePieceProcessor::serialized_model_proto() const { // std::random_device. void SetRandomGeneratorSeed(unsigned int seed); -void ConvertToUnicodeSpans(SentencePieceText *spt) { - if (spt == nullptr) return; - - std::vector utf8_to_unicode(spt->text().size() + 1, 0); - absl::string_view str = spt->text(); - size_t prev = 0; - int ulen = 0; - while (!str.empty()) { - const size_t mblen = string_util::OneCharLen(str.data()); - for (int i = prev; i < prev + mblen; ++i) { - utf8_to_unicode[i] = ulen; - } - ++ulen; - prev += mblen; - str.remove_prefix(mblen); - } - utf8_to_unicode[prev] = ulen; - - auto clip = [&](int s) { - return std::min(std::max(0, s), utf8_to_unicode.size() - 1); - }; - - for (auto &piece : *(spt->mutable_pieces())) { - piece.set_begin(utf8_to_unicode[clip(piece.begin())]); - piece.set_end(utf8_to_unicode[clip(piece.end())]); - } -} - namespace io { util::Status LoadModelProto(absl::string_view filename, ModelProto *model_proto) { diff --git a/src/sentencepiece_processor.h b/src/sentencepiece_processor.h index b7fae6a..d107a2a 100644 --- a/src/sentencepiece_processor.h +++ b/src/sentencepiece_processor.h @@ -25,8 +25,8 @@ #ifndef SWIG namespace absl { using std::string_view; -} -#endif // SWIG +} // namespace absl +#endif namespace sentencepiece { namespace util { @@ -196,6 +196,9 @@ class ImmutableSentencePieceText { // it returns the raw pointer managed by the shared_ptr. SentencePieceText *mutable_proto(); + // Converts the utf8 byte spans into Unicode char span. + void ConvertToUnicodeSpans(); + friend class ImmutableNBestSentencePieceText; private: @@ -225,6 +228,8 @@ class ImmutableNBestSentencePieceText { // it returns the raw pointer managed by the shared_ptr. NBestSentencePieceText *mutable_proto(); + void ConvertToUnicodeSpans(); + private: std::shared_ptr rep_; }; @@ -415,14 +420,16 @@ class SentencePieceProcessor { virtual util::Status Decode(const std::vector &ids, SentencePieceText *spt) const; -#ifdef SWIG +#ifdef SWIGPYTHON +#define CONVERT_TO_UNICODE_SPAN output.ConvertToUnicodeSpans(); #define SPP_SWIG_CHECK_AND_THROW \ if (!status.ok()) throw status; #else +#define CONVERT_TO_UNICODE_SPAN #define SPP_SWIG_CHECK_AND_THROW \ if (!status.ok()) { \ } -#endif // SWIG +#endif // SWIGPYTHON #define DEFINE_SPP_DIRECT_FUNC_IMPL(FuncName, OutType, ...) \ OutType output; \ @@ -439,6 +446,7 @@ class SentencePieceProcessor { #define DEFINE_SPP_IMMUTABLE_PROTO_IMPL(FuncName, OutType, ...) \ OutType output; \ const auto status = FuncName(__VA_ARGS__, output.mutable_proto()); \ + CONVERT_TO_UNICODE_SPAN; \ SPP_SWIG_CHECK_AND_THROW; \ return output; @@ -707,9 +715,6 @@ class SentencePieceProcessor { // std::random_device. void SetRandomGeneratorSeed(unsigned int seed); -// Converts the utf8 byte spans into Unicode char span. -void ConvertToUnicodeSpans(SentencePieceText *spt); - #ifndef SWIG // IO related functions to absorb model formats. namespace io { diff --git a/src/sentencepiece_processor_test.cc b/src/sentencepiece_processor_test.cc index ff55aeb..f05dc5d 100644 --- a/src/sentencepiece_processor_test.cc +++ b/src/sentencepiece_processor_test.cc @@ -1657,11 +1657,12 @@ TEST(SentencePieceProcessorTest, ImmutableNBestSentencePieceTextTest) { TEST(SentencePieceProcessorTest, ConvertToUnicodeSpansTest) { auto make_spt = [&](const std::vector &tokens) { - SentencePieceText spt; + ImmutableSentencePieceText ispt; + auto *spt = ispt.mutable_proto(); int prev = 0; std::string text; for (const auto &tok : tokens) { - auto *piece = spt.add_pieces(); + auto *piece = spt->add_pieces(); piece->set_surface(tok); piece->set_piece(tok); piece->set_begin(prev); @@ -1669,9 +1670,9 @@ TEST(SentencePieceProcessorTest, ConvertToUnicodeSpansTest) { prev += tok.size(); text += tok; } - spt.set_text(text); - ConvertToUnicodeSpans(&spt); - return spt; + spt->set_text(text); + ispt.ConvertToUnicodeSpans(); + return ispt; }; {