[PATCH] Sync from Piper @473817856
author Mike Kruskal <mkruskal@google.com>
Mon, 12 Sep 2022 21:39:23 +0000 (14:39 -0700)
committer Helmut Grohne <helmut@subdivi.de>
Tue, 4 Apr 2023 15:09:31 +0000 (16:09 +0100)
PROTOBUF_SYNC_PIPER

Gbp-Pq: Name CVE-2022-1941.patch

src/google/protobuf/extension_set_heavy.cc
src/google/protobuf/wire_format.cc
src/google/protobuf/wire_format_unittest.cc

index a3c841671fb5f148282a27f8271b9d412e216a0e..fec0bfdb94a817a1192ce24107b0c074581dcd3d 100644 (file)
@@ -669,6 +669,8 @@ bool ExtensionSet::ParseMessageSetItem(io::CodedInputStream* input,
   //   required data message = 3;
 
   uint32 last_type_id = 0;
+  enum class State { kNoTag, kHasType, kHasPayload, kDone };
+  State state = State::kNoTag;
 
   // If we see message data before the type_id, we'll append it to this so
   // we can parse it later.
@@ -682,9 +684,12 @@ bool ExtensionSet::ParseMessageSetItem(io::CodedInputStream* input,
       case WireFormatLite::kMessageSetTypeIdTag: {
         uint32 type_id;
         if (!input->ReadVarint32(&type_id)) return false;
-        last_type_id = type_id;
 
-        if (!message_data.empty()) {
+        if (state == State::kNoTag) {
+          last_type_id = type_id;
+          state = State::kHasType;
+        } else if (state == State::kHasPayload) {
+          last_type_id = type_id;
           // We saw some message data before the type_id.  Have to parse it
           // now.
           io::CodedInputStream sub_input(
@@ -696,22 +701,26 @@ bool ExtensionSet::ParseMessageSetItem(io::CodedInputStream* input,
             return false;
           }
           message_data.clear();
+          state = State::kDone;
         }
 
         break;
       }
 
       case WireFormatLite::kMessageSetMessageTag: {
-        if (last_type_id == 0) {
+        if (state != State::kHasType) {
           // We haven't seen a type_id yet.  Append this data to message_data.
           string temp;
           uint32 length;
           if (!input->ReadVarint32(&length)) return false;
           if (!input->ReadString(&temp, length)) return false;
-          io::StringOutputStream output_stream(&message_data);
-          io::CodedOutputStream coded_output(&output_stream);
-          coded_output.WriteVarint32(length);
-          coded_output.WriteString(temp);
+          if (state == State::kNoTag) {
+            io::StringOutputStream output_stream(&message_data);
+            io::CodedOutputStream coded_output(&output_stream);
+            coded_output.WriteVarint32(length);
+            coded_output.WriteString(temp);
+            state = State::kHasPayload;
+          }
         } else {
           // Already saw type_id, so we can parse this directly.
           if (!ParseFieldMaybeLazily(WireFormatLite::WIRETYPE_LENGTH_DELIMITED,
@@ -719,6 +728,7 @@ bool ExtensionSet::ParseMessageSetItem(io::CodedInputStream* input,
                                      extension_finder, field_skipper)) {
             return false;
           }
+          state = State::kDone;
         }
 
         break;
index 3fdf84edef459c21bfc0529e5433d448d7e75c3d..b153e16c87d33258ea671ebc36eb8615115b5534 100644 (file)
@@ -725,10 +725,8 @@ bool WireFormat::ParseAndMergeMessageSetItem(
   //   required data message = 3;
 
   uint32 last_type_id = 0;
-
-  // Once we see a type_id, we'll look up the FieldDescriptor for the
-  // extension.
-  const FieldDescriptor* field = NULL;
+  enum class State { kNoTag, kHasType, kHasPayload, kDone };
+  State state = State::kNoTag;
 
   // If we see message data before the type_id, we'll append it to this so
   // we can parse it later.
@@ -742,12 +740,15 @@ bool WireFormat::ParseAndMergeMessageSetItem(
       case WireFormatLite::kMessageSetTypeIdTag: {
         uint32 type_id;
         if (!input->ReadVarint32(&type_id)) return false;
-        last_type_id = type_id;
-        field = message_reflection->FindKnownExtensionByNumber(type_id);
 
-        if (!message_data.empty()) {
+        if (state == State::kNoTag) {
+          last_type_id = type_id;
+          state = State::kHasType;
+        } else if (state == State::kHasPayload) {
           // We saw some message data before the type_id.  Have to parse it
           // now.
+          last_type_id = type_id;
+          const FieldDescriptor* field = message_reflection->FindKnownExtensionByNumber(type_id);
           io::ArrayInputStream raw_input(message_data.data(),
                                          message_data.size());
           io::CodedInputStream sub_input(&raw_input);
@@ -756,13 +757,14 @@ bool WireFormat::ParseAndMergeMessageSetItem(
             return false;
           }
           message_data.clear();
+          state = State::kDone;
         }
 
         break;
       }
 
       case WireFormatLite::kMessageSetMessageTag: {
-        if (last_type_id == 0) {
+        if (state == State::kNoTag) {
           // We haven't seen a type_id yet.  Append this data to message_data.
           string temp;
           uint32 length;
@@ -772,12 +774,17 @@ bool WireFormat::ParseAndMergeMessageSetItem(
           io::CodedOutputStream coded_output(&output_stream);
           coded_output.WriteVarint32(length);
           coded_output.WriteString(temp);
-        } else {
+          state = State::kHasPayload;
+        } else if (state == State::kHasType) {
           // Already saw type_id, so we can parse this directly.
+          const FieldDescriptor* field = message_reflection->FindKnownExtensionByNumber(last_type_id);
           if (!ParseAndMergeMessageSetField(last_type_id, field, message,
                                             input)) {
             return false;
           }
+          state = State::kDone;
+        } else {
+          if (!SkipField(input, tag, NULL)) return false;
         }
 
         break;
index 736a12828f5244576f2549522d29096e591520ff..64d39099e2a62a65e4a7daafde2bd7f2fe89d89d 100644 (file)
@@ -35,6 +35,7 @@
 #include <google/protobuf/wire_format.h>
 #include <google/protobuf/wire_format_lite_inl.h>
 #include <google/protobuf/descriptor.h>
+#include <google/protobuf/dynamic_message.h>
 #include <google/protobuf/io/zero_copy_stream_impl.h>
 #include <google/protobuf/io/coded_stream.h>
 #include <google/protobuf/unittest.pb.h>
@@ -563,34 +564,65 @@ TEST(WireFormatTest, ParseMessageSet) {
   EXPECT_EQ(message_set.DebugString(), dynamic_message_set.DebugString());
 }
 
-TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) {
+namespace {
+std::string BuildMessageSetItemStart() {
   string data;
   {
-    unittest::TestMessageSetExtension1 message;
-    message.set_i(123);
-    // Build a MessageSet manually with its message content put before its
-    // type_id.
     io::StringOutputStream output_stream(&data);
     io::CodedOutputStream coded_output(&output_stream);
     coded_output.WriteTag(WireFormatLite::kMessageSetItemStartTag);
+  }
+  return data;
+}
+std::string BuildMessageSetItemEnd() {
+  std::string data;
+  {
+    io::StringOutputStream output_stream(&data);
+    io::CodedOutputStream coded_output(&output_stream);
+    coded_output.WriteTag(WireFormatLite::kMessageSetItemEndTag);
+  }
+  return data;
+}
+std::string BuildMessageSetTestExtension1(int value = 123) {
+  std::string data;
+  {
+    unittest::TestMessageSetExtension1 message;
+    message.set_i(value);
+    io::StringOutputStream output_stream(&data);
+    io::CodedOutputStream coded_output(&output_stream);
     // Write the message content first.
     WireFormatLite::WriteTag(WireFormatLite::kMessageSetMessageNumber,
                              WireFormatLite::WIRETYPE_LENGTH_DELIMITED,
                              &coded_output);
     coded_output.WriteVarint32(message.ByteSize());
     message.SerializeWithCachedSizes(&coded_output);
-    // Write the type id.
-    uint32 type_id = message.GetDescriptor()->extension(0)->number();
+  }
+  return data;
+}
+std::string BuildMessageSetItemTypeId(int extension_number) {
+  std::string data;
+  {
+    io::StringOutputStream output_stream(&data);
+    io::CodedOutputStream coded_output(&output_stream);
     WireFormatLite::WriteUInt32(WireFormatLite::kMessageSetTypeIdNumber,
-                                type_id, &coded_output);
-    coded_output.WriteTag(WireFormatLite::kMessageSetItemEndTag);
+                                extension_number, &coded_output);
   }
+  return data;
+}
+void ValidateTestMessageSet(const std::string& test_case,
+                            const std::string& data) {
+  SCOPED_TRACE(test_case);
   {
     proto2_wireformat_unittest::TestMessageSet message_set;
     ASSERT_TRUE(message_set.ParseFromString(data));
 
     EXPECT_EQ(123, message_set.GetExtension(
         unittest::TestMessageSetExtension1::message_set_extension).i());
+
+    // Make sure it does not contain anything else.
+    message_set.ClearExtension(
+        unittest::TestMessageSetExtension1::message_set_extension);
+    EXPECT_EQ(message_set.SerializeAsString(), "");
   }
   {
     // Test parse the message via Reflection.
@@ -603,6 +635,61 @@ TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) {
     EXPECT_EQ(123, message_set.GetExtension(
         unittest::TestMessageSetExtension1::message_set_extension).i());
   }
+  {
+    // Test parse the message via DynamicMessage.
+    DynamicMessageFactory factory;
+    std::unique_ptr<Message> msg(
+        factory
+            .GetPrototype(
+                proto2_wireformat_unittest::TestMessageSet::descriptor())
+            ->New());
+    msg->ParseFromString(data);
+    auto* reflection = msg->GetReflection();
+    std::vector<const FieldDescriptor*> fields;
+    reflection->ListFields(*msg, &fields);
+    ASSERT_EQ(fields.size(), 1);
+    const auto& sub = reflection->GetMessage(*msg, fields[0]);
+    reflection = sub.GetReflection();
+    EXPECT_EQ(123, reflection->GetInt32(
+                       sub, sub.GetDescriptor()->FindFieldByName("i")));
+  }
+}
+}  // namespace
+
+TEST(WireFormatTest, ParseMessageSetWithAnyTagOrder) {
+  std::string start = BuildMessageSetItemStart();
+  std::string end = BuildMessageSetItemEnd();
+  std::string id = BuildMessageSetItemTypeId(
+      unittest::TestMessageSetExtension1::descriptor()->extension(0)->number());
+  std::string message = BuildMessageSetTestExtension1();
+
+  ValidateTestMessageSet("id + message", start + id + message + end);
+  ValidateTestMessageSet("message + id", start + message + id + end);
+}
+
+TEST(WireFormatTest, ParseMessageSetWithDuplicateTags) {
+  std::string start = BuildMessageSetItemStart();
+  std::string end = BuildMessageSetItemEnd();
+  std::string id = BuildMessageSetItemTypeId(
+      unittest::TestMessageSetExtension1::descriptor()->extension(0)->number());
+  std::string other_id = BuildMessageSetItemTypeId(123456);
+  std::string message = BuildMessageSetTestExtension1();
+  std::string other_message = BuildMessageSetTestExtension1(321);
+
+  // Double id
+  ValidateTestMessageSet("id + other_id + message",
+                         start + id + other_id + message + end);
+  ValidateTestMessageSet("id + message + other_id",
+                         start + id + message + other_id + end);
+  ValidateTestMessageSet("message + id + other_id",
+                         start + message + id + other_id + end);
+  // Double message
+  ValidateTestMessageSet("id + message + other_message",
+                         start + id + message + other_message + end);
+  ValidateTestMessageSet("message + id + other_message",
+                         start + message + id + other_message + end);
+  ValidateTestMessageSet("message + other_message + id",
+                         start + message + other_message + id + end);
 }
 
 TEST(WireFormatTest, ParseBrokenMessageSet) {