From: Mike Kruskal Date: Mon, 12 Sep 2022 21:39:23 +0000 (-0700) Subject: [PATCH] Sync from Piper @473817856 X-Git-Tag: archive/raspbian/3.6.1.3-2+rpi1+deb10u1^2~1 X-Git-Url: https://dgit.raspbian.org/?a=commitdiff_plain;h=d025354a1085b19c815c79077d8f53296ad4a69c;p=protobuf.git [PATCH] Sync from Piper @473817856 PROTOBUF_SYNC_PIPER Gbp-Pq: Name CVE-2022-1941.patch --- diff --git a/src/google/protobuf/extension_set_heavy.cc b/src/google/protobuf/extension_set_heavy.cc index a3c8416..fec0bfd 100644 --- a/src/google/protobuf/extension_set_heavy.cc +++ b/src/google/protobuf/extension_set_heavy.cc @@ -669,6 +669,8 @@ bool ExtensionSet::ParseMessageSetItem(io::CodedInputStream* input, // required data message = 3; uint32 last_type_id = 0; + enum class State { kNoTag, kHasType, kHasPayload, kDone }; + State state = State::kNoTag; // If we see message data before the type_id, we'll append it to this so // we can parse it later. @@ -682,9 +684,12 @@ bool ExtensionSet::ParseMessageSetItem(io::CodedInputStream* input, case WireFormatLite::kMessageSetTypeIdTag: { uint32 type_id; if (!input->ReadVarint32(&type_id)) return false; - last_type_id = type_id; - if (!message_data.empty()) { + if (state == State::kNoTag) { + last_type_id = type_id; + state = State::kHasType; + } else if (state == State::kHasPayload) { + last_type_id = type_id; // We saw some message data before the type_id. Have to parse it // now. io::CodedInputStream sub_input( @@ -696,22 +701,26 @@ bool ExtensionSet::ParseMessageSetItem(io::CodedInputStream* input, return false; } message_data.clear(); + state = State::kDone; } break; } case WireFormatLite::kMessageSetMessageTag: { - if (last_type_id == 0) { + if (state != State::kHasType) { // We haven't seen a type_id yet. Append this data to message_data. string temp; uint32 length; if (!input->ReadVarint32(&length)) return false; if (!input->ReadString(&temp, length)) return false; - io::StringOutputStream output_stream(&message_data); - io::CodedOutputStream coded_output(&output_stream); - coded_output.WriteVarint32(length); - coded_output.WriteString(temp); + if (state == State::kNoTag) { + io::StringOutputStream output_stream(&message_data); + io::CodedOutputStream coded_output(&output_stream); + coded_output.WriteVarint32(length); + coded_output.WriteString(temp); + state = State::kHasPayload; + } } else { // Already saw type_id, so we can parse this directly. if (!ParseFieldMaybeLazily(WireFormatLite::WIRETYPE_LENGTH_DELIMITED, @@ -719,6 +728,7 @@ bool ExtensionSet::ParseMessageSetItem(io::CodedInputStream* input, extension_finder, field_skipper)) { return false; } + state = State::kDone; } break; diff --git a/src/google/protobuf/wire_format.cc b/src/google/protobuf/wire_format.cc index 3fdf84e..b153e16 100644 --- a/src/google/protobuf/wire_format.cc +++ b/src/google/protobuf/wire_format.cc @@ -725,10 +725,8 @@ bool WireFormat::ParseAndMergeMessageSetItem( // required data message = 3; uint32 last_type_id = 0; - - // Once we see a type_id, we'll look up the FieldDescriptor for the - // extension. - const FieldDescriptor* field = NULL; + enum class State { kNoTag, kHasType, kHasPayload, kDone }; + State state = State::kNoTag; // If we see message data before the type_id, we'll append it to this so // we can parse it later. @@ -742,12 +740,15 @@ bool WireFormat::ParseAndMergeMessageSetItem( case WireFormatLite::kMessageSetTypeIdTag: { uint32 type_id; if (!input->ReadVarint32(&type_id)) return false; - last_type_id = type_id; - field = message_reflection->FindKnownExtensionByNumber(type_id); - if (!message_data.empty()) { + if (state == State::kNoTag) { + last_type_id = type_id; + state = State::kHasType; + } else if (state == State::kHasPayload) { // We saw some message data before the type_id. Have to parse it // now. + last_type_id = type_id; + const FieldDescriptor* field = message_reflection->FindKnownExtensionByNumber(type_id); io::ArrayInputStream raw_input(message_data.data(), message_data.size()); io::CodedInputStream sub_input(&raw_input); @@ -756,13 +757,14 @@ bool WireFormat::ParseAndMergeMessageSetItem( return false; } message_data.clear(); + state = State::kDone; } break; } case WireFormatLite::kMessageSetMessageTag: { - if (last_type_id == 0) { + if (state == State::kNoTag) { // We haven't seen a type_id yet. Append this data to message_data. string temp; uint32 length; @@ -772,12 +774,17 @@ bool WireFormat::ParseAndMergeMessageSetItem( io::CodedOutputStream coded_output(&output_stream); coded_output.WriteVarint32(length); coded_output.WriteString(temp); - } else { + state = State::kHasPayload; + } else if (state == State::kHasType) { // Already saw type_id, so we can parse this directly. + const FieldDescriptor* field = message_reflection->FindKnownExtensionByNumber(last_type_id); if (!ParseAndMergeMessageSetField(last_type_id, field, message, input)) { return false; } + state = State::kDone; + } else { + if (!SkipField(input, tag, NULL)) return false; } break; diff --git a/src/google/protobuf/wire_format_unittest.cc b/src/google/protobuf/wire_format_unittest.cc index 736a128..64d3909 100644 --- a/src/google/protobuf/wire_format_unittest.cc +++ b/src/google/protobuf/wire_format_unittest.cc @@ -35,6 +35,7 @@ #include #include #include +#include #include #include #include @@ -563,34 +564,65 @@ TEST(WireFormatTest, ParseMessageSet) { EXPECT_EQ(message_set.DebugString(), dynamic_message_set.DebugString()); } -TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) { +namespace { +std::string BuildMessageSetItemStart() { string data; { - unittest::TestMessageSetExtension1 message; - message.set_i(123); - // Build a MessageSet manually with its message content put before its - // type_id. io::StringOutputStream output_stream(&data); io::CodedOutputStream coded_output(&output_stream); coded_output.WriteTag(WireFormatLite::kMessageSetItemStartTag); + } + return data; +} +std::string BuildMessageSetItemEnd() { + std::string data; + { + io::StringOutputStream output_stream(&data); + io::CodedOutputStream coded_output(&output_stream); + coded_output.WriteTag(WireFormatLite::kMessageSetItemEndTag); + } + return data; +} +std::string BuildMessageSetTestExtension1(int value = 123) { + std::string data; + { + unittest::TestMessageSetExtension1 message; + message.set_i(value); + io::StringOutputStream output_stream(&data); + io::CodedOutputStream coded_output(&output_stream); // Write the message content first. WireFormatLite::WriteTag(WireFormatLite::kMessageSetMessageNumber, WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &coded_output); coded_output.WriteVarint32(message.ByteSize()); message.SerializeWithCachedSizes(&coded_output); - // Write the type id. - uint32 type_id = message.GetDescriptor()->extension(0)->number(); + } + return data; +} +std::string BuildMessageSetItemTypeId(int extension_number) { + std::string data; + { + io::StringOutputStream output_stream(&data); + io::CodedOutputStream coded_output(&output_stream); WireFormatLite::WriteUInt32(WireFormatLite::kMessageSetTypeIdNumber, - type_id, &coded_output); - coded_output.WriteTag(WireFormatLite::kMessageSetItemEndTag); + extension_number, &coded_output); } + return data; +} +void ValidateTestMessageSet(const std::string& test_case, + const std::string& data) { + SCOPED_TRACE(test_case); { proto2_wireformat_unittest::TestMessageSet message_set; ASSERT_TRUE(message_set.ParseFromString(data)); EXPECT_EQ(123, message_set.GetExtension( unittest::TestMessageSetExtension1::message_set_extension).i()); + + // Make sure it does not contain anything else. + message_set.ClearExtension( + unittest::TestMessageSetExtension1::message_set_extension); + EXPECT_EQ(message_set.SerializeAsString(), ""); } { // Test parse the message via Reflection. @@ -603,6 +635,61 @@ TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) { EXPECT_EQ(123, message_set.GetExtension( unittest::TestMessageSetExtension1::message_set_extension).i()); } + { + // Test parse the message via DynamicMessage. + DynamicMessageFactory factory; + std::unique_ptr msg( + factory + .GetPrototype( + proto2_wireformat_unittest::TestMessageSet::descriptor()) + ->New()); + msg->ParseFromString(data); + auto* reflection = msg->GetReflection(); + std::vector fields; + reflection->ListFields(*msg, &fields); + ASSERT_EQ(fields.size(), 1); + const auto& sub = reflection->GetMessage(*msg, fields[0]); + reflection = sub.GetReflection(); + EXPECT_EQ(123, reflection->GetInt32( + sub, sub.GetDescriptor()->FindFieldByName("i"))); + } +} +} // namespace + +TEST(WireFormatTest, ParseMessageSetWithAnyTagOrder) { + std::string start = BuildMessageSetItemStart(); + std::string end = BuildMessageSetItemEnd(); + std::string id = BuildMessageSetItemTypeId( + unittest::TestMessageSetExtension1::descriptor()->extension(0)->number()); + std::string message = BuildMessageSetTestExtension1(); + + ValidateTestMessageSet("id + message", start + id + message + end); + ValidateTestMessageSet("message + id", start + message + id + end); +} + +TEST(WireFormatTest, ParseMessageSetWithDuplicateTags) { + std::string start = BuildMessageSetItemStart(); + std::string end = BuildMessageSetItemEnd(); + std::string id = BuildMessageSetItemTypeId( + unittest::TestMessageSetExtension1::descriptor()->extension(0)->number()); + std::string other_id = BuildMessageSetItemTypeId(123456); + std::string message = BuildMessageSetTestExtension1(); + std::string other_message = BuildMessageSetTestExtension1(321); + + // Double id + ValidateTestMessageSet("id + other_id + message", + start + id + other_id + message + end); + ValidateTestMessageSet("id + message + other_id", + start + id + message + other_id + end); + ValidateTestMessageSet("message + id + other_id", + start + message + id + other_id + end); + // Double message + ValidateTestMessageSet("id + message + other_message", + start + id + message + other_message + end); + ValidateTestMessageSet("message + id + other_message", + start + message + id + other_message + end); + ValidateTestMessageSet("message + other_message + id", + start + message + other_message + id + end); } TEST(WireFormatTest, ParseBrokenMessageSet) {