Compare commits

...

3 commits

6 changed files with 61 additions and 230 deletions

View file

@ -91,14 +91,12 @@ pub fn start(r: crate::HostInterfaceResources, spawner: Spawner, crc: crc_engine
#[embassy_executor::task] #[embassy_executor::task]
async fn tx_task(mut tx: Tx, crc: crc_engine::CrcHandle) { async fn tx_task(mut tx: Tx, crc: crc_engine::CrcHandle) {
let test_message = Message { let test_message = TargetMessage {
inner: Some(Message_::Inner::Resp(Response { msg: Some(TargetMessage_::Msg::Test(TestResponse {
msg: Some(Response_::Msg::Test(TestResponse { f1: 4567,
f1: 4567, f2: String::from_str("7865").unwrap(),
f2: String::from_str("7865").unwrap(), f3: true,
f3: true, f4: Vec::from_slice(b"abcde").unwrap(),
f4: Vec::from_slice(b"abcde").unwrap(),
})),
})), })),
}; };
#[allow(clippy::cast_possible_truncation)] #[allow(clippy::cast_possible_truncation)]

View file

@ -438,156 +438,32 @@ pub mod test_ {
} }
} }
pub mod api_ { pub mod api_ {
pub mod Message_ { pub mod HostMessage_ {
#[derive(Debug, PartialEq, Clone)]
pub enum Inner {
Req(super::Request),
Resp(super::Response),
}
}
#[derive(Debug, Clone)]
pub struct Message {
pub r#inner: ::core::option::Option<Message_::Inner>,
}
impl ::core::default::Default for Message {
fn default() -> Self {
Self {
r#inner: ::core::default::Default::default(),
}
}
}
impl ::core::cmp::PartialEq for Message {
fn eq(&self, other: &Self) -> bool {
let mut ret = true;
ret &= (self.r#inner == other.r#inner);
ret
}
}
impl Message {}
impl ::micropb::MessageDecode for Message {
fn decode<IMPL_MICROPB_READ: ::micropb::PbRead>(
&mut self,
decoder: &mut ::micropb::PbDecoder<IMPL_MICROPB_READ>,
len: usize,
) -> Result<(), ::micropb::DecodeError<IMPL_MICROPB_READ::Error>> {
use ::micropb::{PbVec, PbMap, PbString, FieldDecode};
let before = decoder.bytes_read();
while decoder.bytes_read() - before < len {
let tag = decoder.decode_tag()?;
match tag.field_num() {
0 => return Err(::micropb::DecodeError::ZeroField),
1u32 => {
let mut_ref = loop {
if let ::core::option::Option::Some(variant) = &mut self
.r#inner
{
if let Message_::Inner::Req(variant) = &mut *variant {
break &mut *variant;
}
}
self.r#inner = ::core::option::Option::Some(
Message_::Inner::Req(::core::default::Default::default()),
);
};
mut_ref.decode_len_delimited(decoder)?;
}
2u32 => {
let mut_ref = loop {
if let ::core::option::Option::Some(variant) = &mut self
.r#inner
{
if let Message_::Inner::Resp(variant) = &mut *variant {
break &mut *variant;
}
}
self.r#inner = ::core::option::Option::Some(
Message_::Inner::Resp(::core::default::Default::default()),
);
};
mut_ref.decode_len_delimited(decoder)?;
}
_ => {
decoder.skip_wire_value(tag.wire_type())?;
}
}
}
Ok(())
}
}
impl ::micropb::MessageEncode for Message {
fn encode<IMPL_MICROPB_WRITE: ::micropb::PbWrite>(
&self,
encoder: &mut ::micropb::PbEncoder<IMPL_MICROPB_WRITE>,
) -> Result<(), IMPL_MICROPB_WRITE::Error> {
use ::micropb::{PbVec, PbMap, PbString, FieldEncode};
if let Some(oneof) = &self.r#inner {
match &*oneof {
Message_::Inner::Req(val_ref) => {
let val_ref = &*val_ref;
encoder.encode_varint32(10u32)?;
val_ref.encode_len_delimited(encoder)?;
}
Message_::Inner::Resp(val_ref) => {
let val_ref = &*val_ref;
encoder.encode_varint32(18u32)?;
val_ref.encode_len_delimited(encoder)?;
}
}
}
Ok(())
}
fn compute_size(&self) -> usize {
use ::micropb::{PbVec, PbMap, PbString, FieldEncode};
let mut size = 0;
if let Some(oneof) = &self.r#inner {
match &*oneof {
Message_::Inner::Req(val_ref) => {
let val_ref = &*val_ref;
size
+= 1usize
+ ::micropb::size::sizeof_len_record(
val_ref.compute_size(),
);
}
Message_::Inner::Resp(val_ref) => {
let val_ref = &*val_ref;
size
+= 1usize
+ ::micropb::size::sizeof_len_record(
val_ref.compute_size(),
);
}
}
}
size
}
}
pub mod Request_ {
#[derive(Debug, PartialEq, Clone)] #[derive(Debug, PartialEq, Clone)]
pub enum Msg { pub enum Msg {
Echo(super::super::echo_::EchoRequest), Echo(super::super::echo_::EchoRequest),
} }
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Request { pub struct HostMessage {
pub r#msg: ::core::option::Option<Request_::Msg>, pub r#msg: ::core::option::Option<HostMessage_::Msg>,
} }
impl ::core::default::Default for Request { impl ::core::default::Default for HostMessage {
fn default() -> Self { fn default() -> Self {
Self { Self {
r#msg: ::core::default::Default::default(), r#msg: ::core::default::Default::default(),
} }
} }
} }
impl ::core::cmp::PartialEq for Request { impl ::core::cmp::PartialEq for HostMessage {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
let mut ret = true; let mut ret = true;
ret &= (self.r#msg == other.r#msg); ret &= (self.r#msg == other.r#msg);
ret ret
} }
} }
impl Request {} impl HostMessage {}
impl ::micropb::MessageDecode for Request { impl ::micropb::MessageDecode for HostMessage {
fn decode<IMPL_MICROPB_READ: ::micropb::PbRead>( fn decode<IMPL_MICROPB_READ: ::micropb::PbRead>(
&mut self, &mut self,
decoder: &mut ::micropb::PbDecoder<IMPL_MICROPB_READ>, decoder: &mut ::micropb::PbDecoder<IMPL_MICROPB_READ>,
@ -604,12 +480,12 @@ pub mod api_ {
if let ::core::option::Option::Some(variant) = &mut self if let ::core::option::Option::Some(variant) = &mut self
.r#msg .r#msg
{ {
if let Request_::Msg::Echo(variant) = &mut *variant { if let HostMessage_::Msg::Echo(variant) = &mut *variant {
break &mut *variant; break &mut *variant;
} }
} }
self.r#msg = ::core::option::Option::Some( self.r#msg = ::core::option::Option::Some(
Request_::Msg::Echo(::core::default::Default::default()), HostMessage_::Msg::Echo(::core::default::Default::default()),
); );
}; };
mut_ref.decode_len_delimited(decoder)?; mut_ref.decode_len_delimited(decoder)?;
@ -622,7 +498,7 @@ pub mod api_ {
Ok(()) Ok(())
} }
} }
impl ::micropb::MessageEncode for Request { impl ::micropb::MessageEncode for HostMessage {
fn encode<IMPL_MICROPB_WRITE: ::micropb::PbWrite>( fn encode<IMPL_MICROPB_WRITE: ::micropb::PbWrite>(
&self, &self,
encoder: &mut ::micropb::PbEncoder<IMPL_MICROPB_WRITE>, encoder: &mut ::micropb::PbEncoder<IMPL_MICROPB_WRITE>,
@ -630,7 +506,7 @@ pub mod api_ {
use ::micropb::{PbVec, PbMap, PbString, FieldEncode}; use ::micropb::{PbVec, PbMap, PbString, FieldEncode};
if let Some(oneof) = &self.r#msg { if let Some(oneof) = &self.r#msg {
match &*oneof { match &*oneof {
Request_::Msg::Echo(val_ref) => { HostMessage_::Msg::Echo(val_ref) => {
let val_ref = &*val_ref; let val_ref = &*val_ref;
encoder.encode_varint32(10u32)?; encoder.encode_varint32(10u32)?;
val_ref.encode_len_delimited(encoder)?; val_ref.encode_len_delimited(encoder)?;
@ -644,7 +520,7 @@ pub mod api_ {
let mut size = 0; let mut size = 0;
if let Some(oneof) = &self.r#msg { if let Some(oneof) = &self.r#msg {
match &*oneof { match &*oneof {
Request_::Msg::Echo(val_ref) => { HostMessage_::Msg::Echo(val_ref) => {
let val_ref = &*val_ref; let val_ref = &*val_ref;
size size
+= 1usize += 1usize
@ -657,7 +533,7 @@ pub mod api_ {
size size
} }
} }
pub mod Response_ { pub mod TargetMessage_ {
#[derive(Debug, PartialEq, Clone)] #[derive(Debug, PartialEq, Clone)]
pub enum Msg { pub enum Msg {
Echo(super::super::echo_::EchoResponse), Echo(super::super::echo_::EchoResponse),
@ -665,25 +541,25 @@ pub mod api_ {
} }
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Response { pub struct TargetMessage {
pub r#msg: ::core::option::Option<Response_::Msg>, pub r#msg: ::core::option::Option<TargetMessage_::Msg>,
} }
impl ::core::default::Default for Response { impl ::core::default::Default for TargetMessage {
fn default() -> Self { fn default() -> Self {
Self { Self {
r#msg: ::core::default::Default::default(), r#msg: ::core::default::Default::default(),
} }
} }
} }
impl ::core::cmp::PartialEq for Response { impl ::core::cmp::PartialEq for TargetMessage {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
let mut ret = true; let mut ret = true;
ret &= (self.r#msg == other.r#msg); ret &= (self.r#msg == other.r#msg);
ret ret
} }
} }
impl Response {} impl TargetMessage {}
impl ::micropb::MessageDecode for Response { impl ::micropb::MessageDecode for TargetMessage {
fn decode<IMPL_MICROPB_READ: ::micropb::PbRead>( fn decode<IMPL_MICROPB_READ: ::micropb::PbRead>(
&mut self, &mut self,
decoder: &mut ::micropb::PbDecoder<IMPL_MICROPB_READ>, decoder: &mut ::micropb::PbDecoder<IMPL_MICROPB_READ>,
@ -700,12 +576,14 @@ pub mod api_ {
if let ::core::option::Option::Some(variant) = &mut self if let ::core::option::Option::Some(variant) = &mut self
.r#msg .r#msg
{ {
if let Response_::Msg::Echo(variant) = &mut *variant { if let TargetMessage_::Msg::Echo(variant) = &mut *variant {
break &mut *variant; break &mut *variant;
} }
} }
self.r#msg = ::core::option::Option::Some( self.r#msg = ::core::option::Option::Some(
Response_::Msg::Echo(::core::default::Default::default()), TargetMessage_::Msg::Echo(
::core::default::Default::default(),
),
); );
}; };
mut_ref.decode_len_delimited(decoder)?; mut_ref.decode_len_delimited(decoder)?;
@ -715,12 +593,14 @@ pub mod api_ {
if let ::core::option::Option::Some(variant) = &mut self if let ::core::option::Option::Some(variant) = &mut self
.r#msg .r#msg
{ {
if let Response_::Msg::Test(variant) = &mut *variant { if let TargetMessage_::Msg::Test(variant) = &mut *variant {
break &mut *variant; break &mut *variant;
} }
} }
self.r#msg = ::core::option::Option::Some( self.r#msg = ::core::option::Option::Some(
Response_::Msg::Test(::core::default::Default::default()), TargetMessage_::Msg::Test(
::core::default::Default::default(),
),
); );
}; };
mut_ref.decode_len_delimited(decoder)?; mut_ref.decode_len_delimited(decoder)?;
@ -733,7 +613,7 @@ pub mod api_ {
Ok(()) Ok(())
} }
} }
impl ::micropb::MessageEncode for Response { impl ::micropb::MessageEncode for TargetMessage {
fn encode<IMPL_MICROPB_WRITE: ::micropb::PbWrite>( fn encode<IMPL_MICROPB_WRITE: ::micropb::PbWrite>(
&self, &self,
encoder: &mut ::micropb::PbEncoder<IMPL_MICROPB_WRITE>, encoder: &mut ::micropb::PbEncoder<IMPL_MICROPB_WRITE>,
@ -741,12 +621,12 @@ pub mod api_ {
use ::micropb::{PbVec, PbMap, PbString, FieldEncode}; use ::micropb::{PbVec, PbMap, PbString, FieldEncode};
if let Some(oneof) = &self.r#msg { if let Some(oneof) = &self.r#msg {
match &*oneof { match &*oneof {
Response_::Msg::Echo(val_ref) => { TargetMessage_::Msg::Echo(val_ref) => {
let val_ref = &*val_ref; let val_ref = &*val_ref;
encoder.encode_varint32(10u32)?; encoder.encode_varint32(10u32)?;
val_ref.encode_len_delimited(encoder)?; val_ref.encode_len_delimited(encoder)?;
} }
Response_::Msg::Test(val_ref) => { TargetMessage_::Msg::Test(val_ref) => {
let val_ref = &*val_ref; let val_ref = &*val_ref;
encoder.encode_varint32(18u32)?; encoder.encode_varint32(18u32)?;
val_ref.encode_len_delimited(encoder)?; val_ref.encode_len_delimited(encoder)?;
@ -760,7 +640,7 @@ pub mod api_ {
let mut size = 0; let mut size = 0;
if let Some(oneof) = &self.r#msg { if let Some(oneof) = &self.r#msg {
match &*oneof { match &*oneof {
Response_::Msg::Echo(val_ref) => { TargetMessage_::Msg::Echo(val_ref) => {
let val_ref = &*val_ref; let val_ref = &*val_ref;
size size
+= 1usize += 1usize
@ -768,7 +648,7 @@ pub mod api_ {
val_ref.compute_size(), val_ref.compute_size(),
); );
} }
Response_::Msg::Test(val_ref) => { TargetMessage_::Msg::Test(val_ref) => {
let val_ref = &*val_ref; let val_ref = &*val_ref;
size size
+= 1usize += 1usize

View file

@ -5,20 +5,13 @@ package api;
import "echo.proto"; import "echo.proto";
import "test.proto"; import "test.proto";
message Message { message HostMessage {
oneof inner {
Request req = 1;
Response resp = 2;
}
}
message Request {
oneof msg { oneof msg {
echo.EchoRequest echo = 1; echo.EchoRequest echo = 1;
} }
} }
message Response { message TargetMessage {
oneof msg { oneof msg {
echo.EchoResponse echo = 1; echo.EchoResponse echo = 1;
test.TestResponse test = 2; test.TestResponse test = 2;

View file

@ -15,17 +15,15 @@ import echo_pb2 as echo__pb2
import test_pb2 as test__pb2 import test_pb2 as test__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\tapi.proto\x12\x03\x61pi\x1a\necho.proto\x1a\ntest.proto\"N\n\x07Message\x12\x1b\n\x03req\x18\x01 \x01(\x0b\x32\x0c.api.RequestH\x00\x12\x1d\n\x04resp\x18\x02 \x01(\x0b\x32\r.api.ResponseH\x00\x42\x07\n\x05inner\"3\n\x07Request\x12!\n\x04\x65\x63ho\x18\x01 \x01(\x0b\x32\x11.echo.EchoRequestH\x00\x42\x05\n\x03msg\"Y\n\x08Response\x12\"\n\x04\x65\x63ho\x18\x01 \x01(\x0b\x32\x12.echo.EchoResponseH\x00\x12\"\n\x04test\x18\x02 \x01(\x0b\x32\x12.test.TestResponseH\x00\x42\x05\n\x03msgb\x06proto3') DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\tapi.proto\x12\x03\x61pi\x1a\necho.proto\x1a\ntest.proto\"7\n\x0bHostMessage\x12!\n\x04\x65\x63ho\x18\x01 \x01(\x0b\x32\x11.echo.EchoRequestH\x00\x42\x05\n\x03msg\"^\n\rTargetMessage\x12\"\n\x04\x65\x63ho\x18\x01 \x01(\x0b\x32\x12.echo.EchoResponseH\x00\x12\"\n\x04test\x18\x02 \x01(\x0b\x32\x12.test.TestResponseH\x00\x42\x05\n\x03msgb\x06proto3')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'api_pb2', globals()) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'api_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False: if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None DESCRIPTOR._options = None
_MESSAGE._serialized_start=42 _HOSTMESSAGE._serialized_start=42
_MESSAGE._serialized_end=120 _HOSTMESSAGE._serialized_end=97
_REQUEST._serialized_start=122 _TARGETMESSAGE._serialized_start=99
_REQUEST._serialized_end=173 _TARGETMESSAGE._serialized_end=193
_RESPONSE._serialized_start=175
_RESPONSE._serialized_end=264
# @@protoc_insertion_point(module_scope) # @@protoc_insertion_point(module_scope)

View file

@ -6,21 +6,13 @@ from typing import ClassVar as _ClassVar, Mapping as _Mapping, Optional as _Opti
DESCRIPTOR: _descriptor.FileDescriptor DESCRIPTOR: _descriptor.FileDescriptor
class Message(_message.Message): class HostMessage(_message.Message):
__slots__ = ["req", "resp"]
REQ_FIELD_NUMBER: _ClassVar[int]
RESP_FIELD_NUMBER: _ClassVar[int]
req: Request
resp: Response
def __init__(self, req: _Optional[_Union[Request, _Mapping]] = ..., resp: _Optional[_Union[Response, _Mapping]] = ...) -> None: ...
class Request(_message.Message):
__slots__ = ["echo"] __slots__ = ["echo"]
ECHO_FIELD_NUMBER: _ClassVar[int] ECHO_FIELD_NUMBER: _ClassVar[int]
echo: _echo_pb2.EchoRequest echo: _echo_pb2.EchoRequest
def __init__(self, echo: _Optional[_Union[_echo_pb2.EchoRequest, _Mapping]] = ...) -> None: ... def __init__(self, echo: _Optional[_Union[_echo_pb2.EchoRequest, _Mapping]] = ...) -> None: ...
class Response(_message.Message): class TargetMessage(_message.Message):
__slots__ = ["echo", "test"] __slots__ = ["echo", "test"]
ECHO_FIELD_NUMBER: _ClassVar[int] ECHO_FIELD_NUMBER: _ClassVar[int]
TEST_FIELD_NUMBER: _ClassVar[int] TEST_FIELD_NUMBER: _ClassVar[int]

View file

@ -28,7 +28,7 @@ from prompt_toolkit.key_binding import KeyBindings, KeyPressEvent
from prompt_toolkit.patch_stdout import patch_stdout from prompt_toolkit.patch_stdout import patch_stdout
from prompt_toolkit.shortcuts import PromptSession from prompt_toolkit.shortcuts import PromptSession
from api_pb2 import Message, Request, Response from api_pb2 import HostMessage, TargetMessage
SYNC_BYTE = b"\xfc" SYNC_BYTE = b"\xfc"
@ -37,10 +37,8 @@ CRC_SIZE = 2
uint16_le = struct.Struct("<H") uint16_le = struct.Struct("<H")
MessageQueue = asyncio.Queue[Message] incoming_messages = asyncio.Queue[TargetMessage]()
outgoing_messages = asyncio.Queue[HostMessage]()
incoming_messages = asyncio.Queue[Message]()
outgoing_messages = asyncio.Queue[Message]()
echo_response_queue = asyncio.Queue[int](maxsize=1) echo_response_queue = asyncio.Queue[int](maxsize=1)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -152,7 +150,7 @@ async def proto_listener(stream: asyncio.StreamReader) -> None:
input_buf.clear() input_buf.clear()
msg = Message() msg = TargetMessage()
try: try:
msg.ParseFromString(message_bytes) msg.ParseFromString(message_bytes)
@ -214,18 +212,10 @@ async def message_handler() -> None:
msg = await incoming_messages.get() msg = await incoming_messages.get()
try: try:
if not msg.HasField("inner"): if not msg.HasField("msg"):
invalid_message("missing 'inner' field") invalid_message("missing 'msg' field")
match msg.WhichOneof("inner"): await handle_target_message(msg)
case "req":
await handle_request(msg.req)
case "resp":
await handle_response(msg.resp)
case _:
invalid_message("unknown message type")
except InvalidMessageError as ex: except InvalidMessageError as ex:
logger.error(ex) # noqa: TRY400 logger.error(ex) # noqa: TRY400
@ -236,30 +226,10 @@ async def message_handler() -> None:
logger.info("Message Handler task cancelled.") logger.info("Message Handler task cancelled.")
async def handle_request(req: Request) -> None: async def handle_target_message(tgt: TargetMessage) -> None:
if not req.HasField("msg"): match tgt.WhichOneof("msg"):
invalid_message("request: missing 'msg' field")
match req.WhichOneof("msg"):
case "echo": case "echo":
if not req.echo.HasField("data"): await echo_response_queue.put(tgt.echo.data)
invalid_message("request: echo: missing 'data' field")
msg = Message()
msg.resp.echo.data = req.echo.data
await outgoing_messages.put(msg)
case _:
invalid_message("request: unknown type")
async def handle_response(resp: Response) -> None:
if not resp.HasField("msg"):
invalid_message("response: missing 'msg' field")
match resp.WhichOneof("msg"):
case "echo":
await echo_response_queue.put(resp.echo.data)
case "test": case "test":
pass pass
@ -271,8 +241,8 @@ async def handle_response(resp: Response) -> None:
async def message_sender() -> None: async def message_sender() -> None:
try: try:
while True: while True:
msg = Message() msg = HostMessage()
msg.req.echo.data = random.randint(0, 2**22) msg.echo.data = random.randint(0, 2**22)
await outgoing_messages.put(msg) await outgoing_messages.put(msg)
@ -283,9 +253,9 @@ async def message_sender() -> None:
logger.error("Timeout waiting for echo response") # noqa: TRY400 logger.error("Timeout waiting for echo response") # noqa: TRY400
else: else:
if response != msg.req.echo.data: if response != msg.echo.data:
logger.error( logger.error(
"Incorrect echo response: expected %d - got %d", msg.req.echo.data, response "Incorrect echo response: expected %d - got %d", msg.echo.data, response
) )
await asyncio.sleep(5) await asyncio.sleep(5)