Compare commits

..

No commits in common. "c472fccaf765ea4efdcc28e903d0e5888902f859" and "55d0dace03fe9c8f5aaa250560e89586ebdd83e7" have entirely different histories.

6 changed files with 230 additions and 61 deletions

View file

@ -91,13 +91,15 @@ pub fn start(r: crate::HostInterfaceResources, spawner: Spawner, crc: crc_engine
#[embassy_executor::task] #[embassy_executor::task]
async fn tx_task(mut tx: Tx, crc: crc_engine::CrcHandle) { async fn tx_task(mut tx: Tx, crc: crc_engine::CrcHandle) {
let test_message = TargetMessage { let test_message = Message {
msg: Some(TargetMessage_::Msg::Test(TestResponse { inner: Some(Message_::Inner::Resp(Response {
msg: Some(Response_::Msg::Test(TestResponse {
f1: 4567, f1: 4567,
f2: String::from_str("7865").unwrap(), f2: String::from_str("7865").unwrap(),
f3: true, f3: true,
f4: Vec::from_slice(b"abcde").unwrap(), f4: Vec::from_slice(b"abcde").unwrap(),
})), })),
})),
}; };
#[allow(clippy::cast_possible_truncation)] #[allow(clippy::cast_possible_truncation)]
let encoded_size = (test_message.compute_size() as u16).to_le_bytes(); let encoded_size = (test_message.compute_size() as u16).to_le_bytes();

View file

@ -438,32 +438,156 @@ pub mod test_ {
} }
} }
pub mod api_ { pub mod api_ {
pub mod HostMessage_ { pub mod Message_ {
#[derive(Debug, PartialEq, Clone)]
pub enum Inner {
Req(super::Request),
Resp(super::Response),
}
}
#[derive(Debug, Clone)]
pub struct Message {
pub r#inner: ::core::option::Option<Message_::Inner>,
}
impl ::core::default::Default for Message {
fn default() -> Self {
Self {
r#inner: ::core::default::Default::default(),
}
}
}
impl ::core::cmp::PartialEq for Message {
fn eq(&self, other: &Self) -> bool {
let mut ret = true;
ret &= (self.r#inner == other.r#inner);
ret
}
}
impl Message {}
impl ::micropb::MessageDecode for Message {
fn decode<IMPL_MICROPB_READ: ::micropb::PbRead>(
&mut self,
decoder: &mut ::micropb::PbDecoder<IMPL_MICROPB_READ>,
len: usize,
) -> Result<(), ::micropb::DecodeError<IMPL_MICROPB_READ::Error>> {
use ::micropb::{PbVec, PbMap, PbString, FieldDecode};
let before = decoder.bytes_read();
while decoder.bytes_read() - before < len {
let tag = decoder.decode_tag()?;
match tag.field_num() {
0 => return Err(::micropb::DecodeError::ZeroField),
1u32 => {
let mut_ref = loop {
if let ::core::option::Option::Some(variant) = &mut self
.r#inner
{
if let Message_::Inner::Req(variant) = &mut *variant {
break &mut *variant;
}
}
self.r#inner = ::core::option::Option::Some(
Message_::Inner::Req(::core::default::Default::default()),
);
};
mut_ref.decode_len_delimited(decoder)?;
}
2u32 => {
let mut_ref = loop {
if let ::core::option::Option::Some(variant) = &mut self
.r#inner
{
if let Message_::Inner::Resp(variant) = &mut *variant {
break &mut *variant;
}
}
self.r#inner = ::core::option::Option::Some(
Message_::Inner::Resp(::core::default::Default::default()),
);
};
mut_ref.decode_len_delimited(decoder)?;
}
_ => {
decoder.skip_wire_value(tag.wire_type())?;
}
}
}
Ok(())
}
}
impl ::micropb::MessageEncode for Message {
fn encode<IMPL_MICROPB_WRITE: ::micropb::PbWrite>(
&self,
encoder: &mut ::micropb::PbEncoder<IMPL_MICROPB_WRITE>,
) -> Result<(), IMPL_MICROPB_WRITE::Error> {
use ::micropb::{PbVec, PbMap, PbString, FieldEncode};
if let Some(oneof) = &self.r#inner {
match &*oneof {
Message_::Inner::Req(val_ref) => {
let val_ref = &*val_ref;
encoder.encode_varint32(10u32)?;
val_ref.encode_len_delimited(encoder)?;
}
Message_::Inner::Resp(val_ref) => {
let val_ref = &*val_ref;
encoder.encode_varint32(18u32)?;
val_ref.encode_len_delimited(encoder)?;
}
}
}
Ok(())
}
fn compute_size(&self) -> usize {
use ::micropb::{PbVec, PbMap, PbString, FieldEncode};
let mut size = 0;
if let Some(oneof) = &self.r#inner {
match &*oneof {
Message_::Inner::Req(val_ref) => {
let val_ref = &*val_ref;
size
+= 1usize
+ ::micropb::size::sizeof_len_record(
val_ref.compute_size(),
);
}
Message_::Inner::Resp(val_ref) => {
let val_ref = &*val_ref;
size
+= 1usize
+ ::micropb::size::sizeof_len_record(
val_ref.compute_size(),
);
}
}
}
size
}
}
pub mod Request_ {
#[derive(Debug, PartialEq, Clone)] #[derive(Debug, PartialEq, Clone)]
pub enum Msg { pub enum Msg {
Echo(super::super::echo_::EchoRequest), Echo(super::super::echo_::EchoRequest),
} }
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct HostMessage { pub struct Request {
pub r#msg: ::core::option::Option<HostMessage_::Msg>, pub r#msg: ::core::option::Option<Request_::Msg>,
} }
impl ::core::default::Default for HostMessage { impl ::core::default::Default for Request {
fn default() -> Self { fn default() -> Self {
Self { Self {
r#msg: ::core::default::Default::default(), r#msg: ::core::default::Default::default(),
} }
} }
} }
impl ::core::cmp::PartialEq for HostMessage { impl ::core::cmp::PartialEq for Request {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
let mut ret = true; let mut ret = true;
ret &= (self.r#msg == other.r#msg); ret &= (self.r#msg == other.r#msg);
ret ret
} }
} }
impl HostMessage {} impl Request {}
impl ::micropb::MessageDecode for HostMessage { impl ::micropb::MessageDecode for Request {
fn decode<IMPL_MICROPB_READ: ::micropb::PbRead>( fn decode<IMPL_MICROPB_READ: ::micropb::PbRead>(
&mut self, &mut self,
decoder: &mut ::micropb::PbDecoder<IMPL_MICROPB_READ>, decoder: &mut ::micropb::PbDecoder<IMPL_MICROPB_READ>,
@ -480,12 +604,12 @@ pub mod api_ {
if let ::core::option::Option::Some(variant) = &mut self if let ::core::option::Option::Some(variant) = &mut self
.r#msg .r#msg
{ {
if let HostMessage_::Msg::Echo(variant) = &mut *variant { if let Request_::Msg::Echo(variant) = &mut *variant {
break &mut *variant; break &mut *variant;
} }
} }
self.r#msg = ::core::option::Option::Some( self.r#msg = ::core::option::Option::Some(
HostMessage_::Msg::Echo(::core::default::Default::default()), Request_::Msg::Echo(::core::default::Default::default()),
); );
}; };
mut_ref.decode_len_delimited(decoder)?; mut_ref.decode_len_delimited(decoder)?;
@ -498,7 +622,7 @@ pub mod api_ {
Ok(()) Ok(())
} }
} }
impl ::micropb::MessageEncode for HostMessage { impl ::micropb::MessageEncode for Request {
fn encode<IMPL_MICROPB_WRITE: ::micropb::PbWrite>( fn encode<IMPL_MICROPB_WRITE: ::micropb::PbWrite>(
&self, &self,
encoder: &mut ::micropb::PbEncoder<IMPL_MICROPB_WRITE>, encoder: &mut ::micropb::PbEncoder<IMPL_MICROPB_WRITE>,
@ -506,7 +630,7 @@ pub mod api_ {
use ::micropb::{PbVec, PbMap, PbString, FieldEncode}; use ::micropb::{PbVec, PbMap, PbString, FieldEncode};
if let Some(oneof) = &self.r#msg { if let Some(oneof) = &self.r#msg {
match &*oneof { match &*oneof {
HostMessage_::Msg::Echo(val_ref) => { Request_::Msg::Echo(val_ref) => {
let val_ref = &*val_ref; let val_ref = &*val_ref;
encoder.encode_varint32(10u32)?; encoder.encode_varint32(10u32)?;
val_ref.encode_len_delimited(encoder)?; val_ref.encode_len_delimited(encoder)?;
@ -520,7 +644,7 @@ pub mod api_ {
let mut size = 0; let mut size = 0;
if let Some(oneof) = &self.r#msg { if let Some(oneof) = &self.r#msg {
match &*oneof { match &*oneof {
HostMessage_::Msg::Echo(val_ref) => { Request_::Msg::Echo(val_ref) => {
let val_ref = &*val_ref; let val_ref = &*val_ref;
size size
+= 1usize += 1usize
@ -533,7 +657,7 @@ pub mod api_ {
size size
} }
} }
pub mod TargetMessage_ { pub mod Response_ {
#[derive(Debug, PartialEq, Clone)] #[derive(Debug, PartialEq, Clone)]
pub enum Msg { pub enum Msg {
Echo(super::super::echo_::EchoResponse), Echo(super::super::echo_::EchoResponse),
@ -541,25 +665,25 @@ pub mod api_ {
} }
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct TargetMessage { pub struct Response {
pub r#msg: ::core::option::Option<TargetMessage_::Msg>, pub r#msg: ::core::option::Option<Response_::Msg>,
} }
impl ::core::default::Default for TargetMessage { impl ::core::default::Default for Response {
fn default() -> Self { fn default() -> Self {
Self { Self {
r#msg: ::core::default::Default::default(), r#msg: ::core::default::Default::default(),
} }
} }
} }
impl ::core::cmp::PartialEq for TargetMessage { impl ::core::cmp::PartialEq for Response {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
let mut ret = true; let mut ret = true;
ret &= (self.r#msg == other.r#msg); ret &= (self.r#msg == other.r#msg);
ret ret
} }
} }
impl TargetMessage {} impl Response {}
impl ::micropb::MessageDecode for TargetMessage { impl ::micropb::MessageDecode for Response {
fn decode<IMPL_MICROPB_READ: ::micropb::PbRead>( fn decode<IMPL_MICROPB_READ: ::micropb::PbRead>(
&mut self, &mut self,
decoder: &mut ::micropb::PbDecoder<IMPL_MICROPB_READ>, decoder: &mut ::micropb::PbDecoder<IMPL_MICROPB_READ>,
@ -576,14 +700,12 @@ pub mod api_ {
if let ::core::option::Option::Some(variant) = &mut self if let ::core::option::Option::Some(variant) = &mut self
.r#msg .r#msg
{ {
if let TargetMessage_::Msg::Echo(variant) = &mut *variant { if let Response_::Msg::Echo(variant) = &mut *variant {
break &mut *variant; break &mut *variant;
} }
} }
self.r#msg = ::core::option::Option::Some( self.r#msg = ::core::option::Option::Some(
TargetMessage_::Msg::Echo( Response_::Msg::Echo(::core::default::Default::default()),
::core::default::Default::default(),
),
); );
}; };
mut_ref.decode_len_delimited(decoder)?; mut_ref.decode_len_delimited(decoder)?;
@ -593,14 +715,12 @@ pub mod api_ {
if let ::core::option::Option::Some(variant) = &mut self if let ::core::option::Option::Some(variant) = &mut self
.r#msg .r#msg
{ {
if let TargetMessage_::Msg::Test(variant) = &mut *variant { if let Response_::Msg::Test(variant) = &mut *variant {
break &mut *variant; break &mut *variant;
} }
} }
self.r#msg = ::core::option::Option::Some( self.r#msg = ::core::option::Option::Some(
TargetMessage_::Msg::Test( Response_::Msg::Test(::core::default::Default::default()),
::core::default::Default::default(),
),
); );
}; };
mut_ref.decode_len_delimited(decoder)?; mut_ref.decode_len_delimited(decoder)?;
@ -613,7 +733,7 @@ pub mod api_ {
Ok(()) Ok(())
} }
} }
impl ::micropb::MessageEncode for TargetMessage { impl ::micropb::MessageEncode for Response {
fn encode<IMPL_MICROPB_WRITE: ::micropb::PbWrite>( fn encode<IMPL_MICROPB_WRITE: ::micropb::PbWrite>(
&self, &self,
encoder: &mut ::micropb::PbEncoder<IMPL_MICROPB_WRITE>, encoder: &mut ::micropb::PbEncoder<IMPL_MICROPB_WRITE>,
@ -621,12 +741,12 @@ pub mod api_ {
use ::micropb::{PbVec, PbMap, PbString, FieldEncode}; use ::micropb::{PbVec, PbMap, PbString, FieldEncode};
if let Some(oneof) = &self.r#msg { if let Some(oneof) = &self.r#msg {
match &*oneof { match &*oneof {
TargetMessage_::Msg::Echo(val_ref) => { Response_::Msg::Echo(val_ref) => {
let val_ref = &*val_ref; let val_ref = &*val_ref;
encoder.encode_varint32(10u32)?; encoder.encode_varint32(10u32)?;
val_ref.encode_len_delimited(encoder)?; val_ref.encode_len_delimited(encoder)?;
} }
TargetMessage_::Msg::Test(val_ref) => { Response_::Msg::Test(val_ref) => {
let val_ref = &*val_ref; let val_ref = &*val_ref;
encoder.encode_varint32(18u32)?; encoder.encode_varint32(18u32)?;
val_ref.encode_len_delimited(encoder)?; val_ref.encode_len_delimited(encoder)?;
@ -640,7 +760,7 @@ pub mod api_ {
let mut size = 0; let mut size = 0;
if let Some(oneof) = &self.r#msg { if let Some(oneof) = &self.r#msg {
match &*oneof { match &*oneof {
TargetMessage_::Msg::Echo(val_ref) => { Response_::Msg::Echo(val_ref) => {
let val_ref = &*val_ref; let val_ref = &*val_ref;
size size
+= 1usize += 1usize
@ -648,7 +768,7 @@ pub mod api_ {
val_ref.compute_size(), val_ref.compute_size(),
); );
} }
TargetMessage_::Msg::Test(val_ref) => { Response_::Msg::Test(val_ref) => {
let val_ref = &*val_ref; let val_ref = &*val_ref;
size size
+= 1usize += 1usize

View file

@ -5,13 +5,20 @@ package api;
import "echo.proto"; import "echo.proto";
import "test.proto"; import "test.proto";
message HostMessage { message Message {
oneof inner {
Request req = 1;
Response resp = 2;
}
}
message Request {
oneof msg { oneof msg {
echo.EchoRequest echo = 1; echo.EchoRequest echo = 1;
} }
} }
message TargetMessage { message Response {
oneof msg { oneof msg {
echo.EchoResponse echo = 1; echo.EchoResponse echo = 1;
test.TestResponse test = 2; test.TestResponse test = 2;

View file

@ -15,15 +15,17 @@ import echo_pb2 as echo__pb2
import test_pb2 as test__pb2 import test_pb2 as test__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\tapi.proto\x12\x03\x61pi\x1a\necho.proto\x1a\ntest.proto\"7\n\x0bHostMessage\x12!\n\x04\x65\x63ho\x18\x01 \x01(\x0b\x32\x11.echo.EchoRequestH\x00\x42\x05\n\x03msg\"^\n\rTargetMessage\x12\"\n\x04\x65\x63ho\x18\x01 \x01(\x0b\x32\x12.echo.EchoResponseH\x00\x12\"\n\x04test\x18\x02 \x01(\x0b\x32\x12.test.TestResponseH\x00\x42\x05\n\x03msgb\x06proto3') DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\tapi.proto\x12\x03\x61pi\x1a\necho.proto\x1a\ntest.proto\"N\n\x07Message\x12\x1b\n\x03req\x18\x01 \x01(\x0b\x32\x0c.api.RequestH\x00\x12\x1d\n\x04resp\x18\x02 \x01(\x0b\x32\r.api.ResponseH\x00\x42\x07\n\x05inner\"3\n\x07Request\x12!\n\x04\x65\x63ho\x18\x01 \x01(\x0b\x32\x11.echo.EchoRequestH\x00\x42\x05\n\x03msg\"Y\n\x08Response\x12\"\n\x04\x65\x63ho\x18\x01 \x01(\x0b\x32\x12.echo.EchoResponseH\x00\x12\"\n\x04test\x18\x02 \x01(\x0b\x32\x12.test.TestResponseH\x00\x42\x05\n\x03msgb\x06proto3')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'api_pb2', globals()) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'api_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False: if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None DESCRIPTOR._options = None
_HOSTMESSAGE._serialized_start=42 _MESSAGE._serialized_start=42
_HOSTMESSAGE._serialized_end=97 _MESSAGE._serialized_end=120
_TARGETMESSAGE._serialized_start=99 _REQUEST._serialized_start=122
_TARGETMESSAGE._serialized_end=193 _REQUEST._serialized_end=173
_RESPONSE._serialized_start=175
_RESPONSE._serialized_end=264
# @@protoc_insertion_point(module_scope) # @@protoc_insertion_point(module_scope)

View file

@ -6,13 +6,21 @@ from typing import ClassVar as _ClassVar, Mapping as _Mapping, Optional as _Opti
DESCRIPTOR: _descriptor.FileDescriptor DESCRIPTOR: _descriptor.FileDescriptor
class HostMessage(_message.Message): class Message(_message.Message):
__slots__ = ["req", "resp"]
REQ_FIELD_NUMBER: _ClassVar[int]
RESP_FIELD_NUMBER: _ClassVar[int]
req: Request
resp: Response
def __init__(self, req: _Optional[_Union[Request, _Mapping]] = ..., resp: _Optional[_Union[Response, _Mapping]] = ...) -> None: ...
class Request(_message.Message):
__slots__ = ["echo"] __slots__ = ["echo"]
ECHO_FIELD_NUMBER: _ClassVar[int] ECHO_FIELD_NUMBER: _ClassVar[int]
echo: _echo_pb2.EchoRequest echo: _echo_pb2.EchoRequest
def __init__(self, echo: _Optional[_Union[_echo_pb2.EchoRequest, _Mapping]] = ...) -> None: ... def __init__(self, echo: _Optional[_Union[_echo_pb2.EchoRequest, _Mapping]] = ...) -> None: ...
class TargetMessage(_message.Message): class Response(_message.Message):
__slots__ = ["echo", "test"] __slots__ = ["echo", "test"]
ECHO_FIELD_NUMBER: _ClassVar[int] ECHO_FIELD_NUMBER: _ClassVar[int]
TEST_FIELD_NUMBER: _ClassVar[int] TEST_FIELD_NUMBER: _ClassVar[int]

View file

@ -28,7 +28,7 @@ from prompt_toolkit.key_binding import KeyBindings, KeyPressEvent
from prompt_toolkit.patch_stdout import patch_stdout from prompt_toolkit.patch_stdout import patch_stdout
from prompt_toolkit.shortcuts import PromptSession from prompt_toolkit.shortcuts import PromptSession
from api_pb2 import HostMessage, TargetMessage from api_pb2 import Message, Request, Response
SYNC_BYTE = b"\xfc" SYNC_BYTE = b"\xfc"
@ -37,8 +37,10 @@ CRC_SIZE = 2
uint16_le = struct.Struct("<H") uint16_le = struct.Struct("<H")
incoming_messages = asyncio.Queue[TargetMessage]() MessageQueue = asyncio.Queue[Message]
outgoing_messages = asyncio.Queue[HostMessage]()
incoming_messages = asyncio.Queue[Message]()
outgoing_messages = asyncio.Queue[Message]()
echo_response_queue = asyncio.Queue[int](maxsize=1) echo_response_queue = asyncio.Queue[int](maxsize=1)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -150,7 +152,7 @@ async def proto_listener(stream: asyncio.StreamReader) -> None:
input_buf.clear() input_buf.clear()
msg = TargetMessage() msg = Message()
try: try:
msg.ParseFromString(message_bytes) msg.ParseFromString(message_bytes)
@ -212,10 +214,18 @@ async def message_handler() -> None:
msg = await incoming_messages.get() msg = await incoming_messages.get()
try: try:
if not msg.HasField("msg"): if not msg.HasField("inner"):
invalid_message("missing 'msg' field") invalid_message("missing 'inner' field")
await handle_target_message(msg) match msg.WhichOneof("inner"):
case "req":
await handle_request(msg.req)
case "resp":
await handle_response(msg.resp)
case _:
invalid_message("unknown message type")
except InvalidMessageError as ex: except InvalidMessageError as ex:
logger.error(ex) # noqa: TRY400 logger.error(ex) # noqa: TRY400
@ -226,10 +236,30 @@ async def message_handler() -> None:
logger.info("Message Handler task cancelled.") logger.info("Message Handler task cancelled.")
async def handle_target_message(tgt: TargetMessage) -> None: async def handle_request(req: Request) -> None:
match tgt.WhichOneof("msg"): if not req.HasField("msg"):
invalid_message("request: missing 'msg' field")
match req.WhichOneof("msg"):
case "echo": case "echo":
await echo_response_queue.put(tgt.echo.data) if not req.echo.HasField("data"):
invalid_message("request: echo: missing 'data' field")
msg = Message()
msg.resp.echo.data = req.echo.data
await outgoing_messages.put(msg)
case _:
invalid_message("request: unknown type")
async def handle_response(resp: Response) -> None:
if not resp.HasField("msg"):
invalid_message("response: missing 'msg' field")
match resp.WhichOneof("msg"):
case "echo":
await echo_response_queue.put(resp.echo.data)
case "test": case "test":
pass pass
@ -241,8 +271,8 @@ async def handle_target_message(tgt: TargetMessage) -> None:
async def message_sender() -> None: async def message_sender() -> None:
try: try:
while True: while True:
msg = HostMessage() msg = Message()
msg.echo.data = random.randint(0, 2**22) msg.req.echo.data = random.randint(0, 2**22)
await outgoing_messages.put(msg) await outgoing_messages.put(msg)
@ -253,9 +283,9 @@ async def message_sender() -> None:
logger.error("Timeout waiting for echo response") # noqa: TRY400 logger.error("Timeout waiting for echo response") # noqa: TRY400
else: else:
if response != msg.echo.data: if response != msg.req.echo.data:
logger.error( logger.error(
"Incorrect echo response: expected %d - got %d", msg.echo.data, response "Incorrect echo response: expected %d - got %d", msg.req.echo.data, response
) )
await asyncio.sleep(5) await asyncio.sleep(5)