support slice in pieces/nbests objects
authorTaku Kudo <taku@google.com>
Fri, 5 Aug 2022 07:34:44 +0000 (16:34 +0900)
committerKentaro Hayashi <kenhys@xdump.org>
Mon, 21 Nov 2022 13:43:46 +0000 (13:43 +0000)
Signed-off-by: Kentaro Hayashi <kenhys@gmail.com>
Gbp-Pq: Name 0019-support-slice-in-pieces-nbests-objects.patch

python/src/sentencepiece/__init__.py
python/src/sentencepiece/sentencepiece.i
python/test/sentencepiece_test.py

index ce9d60dab649894b18ccb72f6d9d88daf091e8c0..cf06830e62d277ee5a56651b20b77e8800f00e94 100644 (file)
@@ -145,6 +145,10 @@ class ImmutableSentencePieceText(object):
         return self.len
 
       def __getitem__(self, index):
+        if isinstance(index, slice):
+          return [self.proto._pieces(i) for i in range(self.len)][index.start:index.stop:index.step]
+        if index < 0:
+          index = index + self.len
         if index < 0 or index >= self.len:
           raise IndexError('piece index is out of range')
         return self.proto._pieces(index)
@@ -202,6 +206,10 @@ class ImmutableNBestSentencePieceText(object):
         return self.len
 
       def __getitem__(self, index):
+        if isinstance(index, slice):
+          return [self.proto._nbests(i) for i in range(self.len)][index.start:index.stop:index.step]
+        if index < 0:
+          index = index + self.len
         if index < 0 or index >= self.len:
           raise IndexError('nbests index is out of range')
         return self.proto._nbests(index)
index e22f76330c6f799c967a7bf79e64e91cc919fdb8..2ac68a8f99433e95549d4eb9c88f6d04ed5c3ce8 100644 (file)
@@ -1293,6 +1293,10 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
         return self.len
 
       def __getitem__(self, index):
+        if isinstance(index, slice):
+          return [self.proto._pieces(i) for i in range(self.len)][index.start:index.stop:index.step]
+        if index < 0:
+          index = index + self.len
         if index < 0 or index >= self.len:
           raise IndexError('piece index is out of range')
         return self.proto._pieces(index)
@@ -1336,6 +1340,10 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
         return self.len
 
       def __getitem__(self, index):
+        if isinstance(index, slice):
+          return [self.proto._nbests(i) for i in range(self.len)][index.start:index.stop:index.step]
+        if index < 0:
+          index = index + self.len
         if index < 0 or index >= self.len:
           raise IndexError('nbests index is out of range')
         return self.proto._nbests(index)
index 6cbe077d2179d41b8b61d3f4bb71493d3f544168..92327ac3e911d80f7cc70ae2c513a6d378604e2e 100755 (executable)
@@ -395,6 +395,10 @@ class TestSentencepieceProcessor(unittest.TestCase):
       self.assertEqual(
           self.sp_.Decode([x.id for x in s3.nbests[i].pieces]), text)
 
+    # slice
+    self.assertEqual(s1.pieces[::-1], list(reversed(s1.pieces)))
+    self.assertEqual(s3.nbests[::-1], list(reversed(s3.nbests)))
+
     # Japanese offset
     s1 = self.jasp_.EncodeAsImmutableProto('吾輩は猫である。Hello world. ABC 123')
     surfaces1 = [s1.text[x.begin:x.end] for x in s1.pieces]