From fb547400f99b8c63fc83ef8a4ae8f461631308f3 Mon Sep 17 00:00:00 2001 From: Taku Kudo Date: Wed, 3 Aug 2022 17:20:01 +0900 Subject: [PATCH] Fixed test failure. Signed-off-by: Kentaro Hayashi Gbp-Pq: Name 0016-Fixed-test-failure.patch --- python/src/sentencepiece/sentencepiece.i | 35 ++++++++++++++++--- .../src/sentencepiece/sentencepiece_wrap.cxx | 35 ++++++++++++++++--- src/sentencepiece_processor.cc | 4 +-- src/sentencepiece_processor.h | 34 ++++++------------ 4 files changed, 72 insertions(+), 36 deletions(-) diff --git a/python/src/sentencepiece/sentencepiece.i b/python/src/sentencepiece/sentencepiece.i index 75f62c8..1a94fef 100644 --- a/python/src/sentencepiece/sentencepiece.i +++ b/python/src/sentencepiece/sentencepiece.i @@ -193,6 +193,19 @@ inline void CheckIds(const std::vector &ids, int num_pieces) { inline void CheckIds(const std::vector &ids, int num_pieces) {} +template +inline void ConvertToUnicodeSpans(T *proto) {} + +template <> +inline void ConvertToUnicodeSpans(sentencepiece::ImmutableSentencePieceText *proto) { + proto->ConvertToUnicodeSpans(); +} + +template <> +inline void ConvertToUnicodeSpans(sentencepiece::ImmutableNBestSentencePieceText *proto) { + proto->ConvertToUnicodeSpans(); +} + class ThreadPool { public: explicit ThreadPool(size_t request_size) : @@ -239,6 +252,7 @@ inline void InitNumThreads(const std::vector &ins, int *num_threads) { self->FuncName(ins[i]); \ RewriteIds(*self, &out, add_bos, add_eos, reverse, \ emit_unk_piece); \ + ConvertToUnicodeSpans(&out); \ outs[i] = std::move(out); \ } \ }); \ @@ -255,7 +269,9 @@ inline void InitNumThreads(const std::vector &ins, int *num_threads) { pool.Schedule([&, n]() { \ for (size_t i = n; i < ins.size(); i += num_threads) { \ CheckIds(ins[i], self->GetPieceSize()); \ - outs[i] = self->FuncName(ins[i]); \ + auto out = self->FuncName(ins[i]); \ + ConvertToUnicodeSpans(&out); \ + outs[i] = std::move(out); \ } \ }); \ } \ @@ -396,6 +412,7 @@ inline void InitNumThreads(const std::vector &ins, int *num_threads) { auto proto = enable_sampling ? $self->SampleEncodeAsImmutableProto(text, nbest_size, alpha) : $self->EncodeAsImmutableProto(text); + proto.ConvertToUnicodeSpans(); RewriteIds(*$self, &proto, add_bos, add_eos, reverse, emit_unk_piece); return proto; } @@ -467,13 +484,17 @@ inline void InitNumThreads(const std::vector &ins, int *num_threads) { sentencepiece::ImmutableSentencePieceText _DecodeIdsAsImmutableProto( const std::vector &ids) const { CheckIds(ids, $self->GetPieceSize()); - return $self->DecodeIdsAsImmutableProto(ids); + auto proto = $self->DecodeIdsAsImmutableProto(ids); + proto.ConvertToUnicodeSpans(); + return proto; } sentencepiece::ImmutableSentencePieceText _DecodePiecesAsImmutableProto( const std::vector &pieces) const { CheckIds(pieces, $self->GetPieceSize()); - return $self->DecodePiecesAsImmutableProto(pieces); + auto proto= $self->DecodePiecesAsImmutableProto(pieces); + proto.ConvertToUnicodeSpans(); + return proto; } ///////////////////////////////////////////////////////////////////////////// @@ -557,7 +578,9 @@ inline void InitNumThreads(const std::vector &ins, int *num_threads) { bool emit_unk_piece) const { RewriteIds(*$self, static_cast(nullptr), add_bos, add_eos, reverse, emit_unk_piece); - return $self->NBestEncodeAsImmutableProto(text, nbest_size); + auto proto = $self->NBestEncodeAsImmutableProto(text, nbest_size); + proto.ConvertToUnicodeSpans(); + return proto; } @@ -611,8 +634,10 @@ inline void InitNumThreads(const std::vector &ins, int *num_threads) { bool emit_unk_piece) const { RewriteIds(*$self, static_cast(nullptr), add_bos, add_eos, reverse, emit_unk_piece); - return $self->SampleEncodeAndScoreAsImmutableProto(text, num_samples, + auto proto = $self->SampleEncodeAndScoreAsImmutableProto(text, num_samples, alpha, wor, include_best); + proto.ConvertToUnicodeSpans(); + return proto; } diff --git a/python/src/sentencepiece/sentencepiece_wrap.cxx b/python/src/sentencepiece/sentencepiece_wrap.cxx index 22e0708..4b8b5ef 100644 --- a/python/src/sentencepiece/sentencepiece_wrap.cxx +++ b/python/src/sentencepiece/sentencepiece_wrap.cxx @@ -3002,6 +3002,19 @@ inline void CheckIds(const std::vector &ids, int num_pieces) { inline void CheckIds(const std::vector &ids, int num_pieces) {} +template +inline void ConvertToUnicodeSpans(T *proto) {} + +template <> +inline void ConvertToUnicodeSpans(sentencepiece::ImmutableSentencePieceText *proto) { + proto->ConvertToUnicodeSpans(); +} + +template <> +inline void ConvertToUnicodeSpans(sentencepiece::ImmutableNBestSentencePieceText *proto) { + proto->ConvertToUnicodeSpans(); +} + class ThreadPool { public: explicit ThreadPool(size_t request_size) : @@ -3048,6 +3061,7 @@ inline void InitNumThreads(const std::vector &ins, int *num_threads) { self->FuncName(ins[i]); \ RewriteIds(*self, &out, add_bos, add_eos, reverse, \ emit_unk_piece); \ + ConvertToUnicodeSpans(&out); \ outs[i] = std::move(out); \ } \ }); \ @@ -3064,7 +3078,9 @@ inline void InitNumThreads(const std::vector &ins, int *num_threads) { pool.Schedule([&, n]() { \ for (size_t i = n; i < ins.size(); i += num_threads) { \ CheckIds(ins[i], self->GetPieceSize()); \ - outs[i] = self->FuncName(ins[i]); \ + auto out = self->FuncName(ins[i]); \ + ConvertToUnicodeSpans(&out); \ + outs[i] = std::move(out); \ } \ }); \ } \ @@ -3540,6 +3556,7 @@ SWIGINTERN sentencepiece::ImmutableSentencePieceText sentencepiece_SentencePiece auto proto = enable_sampling ? self->SampleEncodeAsImmutableProto(text, nbest_size, alpha) : self->EncodeAsImmutableProto(text); + proto.ConvertToUnicodeSpans(); RewriteIds(*self, &proto, add_bos, add_eos, reverse, emit_unk_piece); return proto; } @@ -3578,11 +3595,15 @@ SWIGINTERN sentencepiece::util::bytes sentencepiece_SentencePieceProcessor__Deco } SWIGINTERN sentencepiece::ImmutableSentencePieceText sentencepiece_SentencePieceProcessor__DecodeIdsAsImmutableProto(sentencepiece::SentencePieceProcessor const *self,std::vector< int > const &ids){ CheckIds(ids, self->GetPieceSize()); - return self->DecodeIdsAsImmutableProto(ids); + auto proto = self->DecodeIdsAsImmutableProto(ids); + proto.ConvertToUnicodeSpans(); + return proto; } SWIGINTERN sentencepiece::ImmutableSentencePieceText sentencepiece_SentencePieceProcessor__DecodePiecesAsImmutableProto(sentencepiece::SentencePieceProcessor const *self,std::vector< absl::string_view > const &pieces){ CheckIds(pieces, self->GetPieceSize()); - return self->DecodePiecesAsImmutableProto(pieces); + auto proto= self->DecodePiecesAsImmutableProto(pieces); + proto.ConvertToUnicodeSpans(); + return proto; } SWIGINTERN std::vector< std::string > sentencepiece_SentencePieceProcessor__DecodeIdsBatch(sentencepiece::SentencePieceProcessor const *self,std::vector< std::vector< int > > const &ins,int num_threads){ DEFINE_DECODE_BATCH_FUNC_IMPL(DecodeIds, int, std::string); @@ -3628,7 +3649,9 @@ SWIGINTERN sentencepiece::util::bytes sentencepiece_SentencePieceProcessor__NBes SWIGINTERN sentencepiece::ImmutableNBestSentencePieceText sentencepiece_SentencePieceProcessor__NBestEncodeAsImmutableProto(sentencepiece::SentencePieceProcessor const *self,absl::string_view text,int nbest_size,bool add_bos,bool add_eos,bool reverse,bool emit_unk_piece){ RewriteIds(*self, static_cast(nullptr), add_bos, add_eos, reverse, emit_unk_piece); - return self->NBestEncodeAsImmutableProto(text, nbest_size); + auto proto = self->NBestEncodeAsImmutableProto(text, nbest_size); + proto.ConvertToUnicodeSpans(); + return proto; } SWIGINTERN std::vector< std::pair< std::vector< int >,float > > sentencepiece_SentencePieceProcessor__SampleEncodeAndScoreAsIds(sentencepiece::SentencePieceProcessor const *self,absl::string_view text,int num_samples,float alpha,bool wor,bool include_best,bool add_bos,bool add_eos,bool reverse,bool emit_unk_piece){ auto idss = self->SampleEncodeAndScoreAsIds(text, num_samples, @@ -3655,8 +3678,10 @@ SWIGINTERN sentencepiece::util::bytes sentencepiece_SentencePieceProcessor__Samp SWIGINTERN sentencepiece::ImmutableNBestSentencePieceText sentencepiece_SentencePieceProcessor__SampleEncodeAndScoreAsImmutableProto(sentencepiece::SentencePieceProcessor const *self,absl::string_view text,int num_samples,float alpha,bool wor,bool include_best,bool add_bos,bool add_eos,bool reverse,bool emit_unk_piece){ RewriteIds(*self, static_cast(nullptr), add_bos, add_eos, reverse, emit_unk_piece); - return self->SampleEncodeAndScoreAsImmutableProto(text, num_samples, + auto proto = self->SampleEncodeAndScoreAsImmutableProto(text, num_samples, alpha, wor, include_best); + proto.ConvertToUnicodeSpans(); + return proto; } SWIGINTERN float sentencepiece_SentencePieceProcessor__CalculateEntropy(sentencepiece::SentencePieceProcessor *self,absl::string_view text,float alpha){ return self->CalculateEntropy(text, alpha); diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc index 2a5c399..f0df2f6 100644 --- a/src/sentencepiece_processor.cc +++ b/src/sentencepiece_processor.cc @@ -56,14 +56,14 @@ std::vector ToPieceArray(const std::vector &v) { } void ConvertToUnicodeSpansInternal(SentencePieceText *spt) { - if (spt == nullptr) return; + if (spt == nullptr || spt->text().empty()) return; std::vector utf8_to_unicode(spt->text().size() + 1, 0); absl::string_view str = spt->text(); size_t prev = 0; int ulen = 0; while (!str.empty()) { - const size_t mblen = string_util::OneCharLen(str.data()); + const size_t mblen = std::max(1, string_util::OneCharLen(str.data())); for (int i = prev; i < prev + mblen; ++i) { utf8_to_unicode[i] = ulen; } diff --git a/src/sentencepiece_processor.h b/src/sentencepiece_processor.h index be9449e..14b1e8c 100644 --- a/src/sentencepiece_processor.h +++ b/src/sentencepiece_processor.h @@ -419,47 +419,33 @@ class SentencePieceProcessor { virtual util::Status Decode(const std::vector &ids, SentencePieceText *spt) const; - -#ifndef SWIGPYTHON - -#define DEFINE_SPP_DIRECT_FUNC_IMPL(FuncName, OutType, ...) \ - OutType output; \ - const auto status = FuncName(__VA_ARGS__, &output); \ - return output; - -#define DEFINE_SPP_SERIALIZED_PROTO_IMPL(FuncName, OutType, ...) \ - OutType output; \ - const auto status = FuncName(__VA_ARGS__, output.mutable_proto()); \ - return output.SerializeAsString(); - -#define DEFINE_SPP_IMMUTABLE_PROTO_IMPL(FuncName, OutType, ...) \ - OutType output; \ - const auto status = FuncName(__VA_ARGS__, output.mutable_proto()); \ - return output; - +#ifdef SWIG +#define SPP_SWIG_CHECK_AND_THROW \ + if (!status.ok()) throw status; #else +#define SPP_SWIG_CHECK_AND_THROW \ + if (!status.ok()) { \ + } +#endif // SWIG #define DEFINE_SPP_DIRECT_FUNC_IMPL(FuncName, OutType, ...) \ OutType output; \ const auto status = FuncName(__VA_ARGS__, &output); \ - if (!status.ok()) throw status; \ + SPP_SWIG_CHECK_AND_THROW; \ return output; #define DEFINE_SPP_SERIALIZED_PROTO_IMPL(FuncName, OutType, ...) \ OutType output; \ const auto status = FuncName(__VA_ARGS__, output.mutable_proto()); \ - if (!status.ok()) throw status; \ + SPP_SWIG_CHECK_AND_THROW; \ return output.SerializeAsString(); #define DEFINE_SPP_IMMUTABLE_PROTO_IMPL(FuncName, OutType, ...) \ OutType output; \ const auto status = FuncName(__VA_ARGS__, output.mutable_proto()); \ - if (!status.ok()) throw status; \ - output.ConvertToUnicodeSpans(); \ + SPP_SWIG_CHECK_AND_THROW; \ return output; -#endif // SWIGPYTHON - ////////////////////////////////////////////////////////////// // Handy methods that return the result directly. // These functions ignore internal errors. -- 2.30.2