Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 90 additions & 10 deletions usb_protocol/types/descriptors/standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,22 +128,23 @@ class DeviceCapabilityTypes(IntEnum):
)


EndpointDescriptor = DescriptorFormat(
EndpointDescriptorLength = construct.Rebuild(construct.Int8ul, 7 if (this.bRefresh is None) and (this.bSynchAddress is None) else 9)

EndpointDescriptor = DescriptorFormat(
# [USB2.0: 9.6; USB Audio Device Class Definition 1.0: 4.6.1.1, 4.6.2.1]
# Interfaces of the Audio 1.0 class extend their subordinate endpoint descriptors with
# 2 additional bytes (extending it from 7 to 9 bytes). Thankfully, this is the only extension that
# changes the length of a standard descriptor type, but we do have to handle this case in Construct.
"bLength" / construct.Default(construct.OneOf(construct.Int8ul, [7, 9]), 7),
"bLength" / EndpointDescriptorLength,
"bDescriptorType" / DescriptorNumber(StandardDescriptorNumbers.ENDPOINT),
"bEndpointAddress" / DescriptorField("Endpoint Address"),
"bmAttributes" / DescriptorField("Attributes", default=2),
"wMaxPacketSize" / DescriptorField("Maximum Packet Size", default=64),
"bInterval" / DescriptorField("Polling interval", default=255),

# 2 bytes that are only present on endpoint descriptors for Audio 1.0 class interfaces.
("bRefresh" / construct.Optional(construct.Int8ul)) * "Refresh Rate",
("bSynchAddress" / construct.Optional(construct.Int8ul)) * "Synch Endpoint Address",
("bRefresh" / construct.If(this.bLength == 9, construct.Optional(construct.Int8ul))) * "Refresh Rate",
("bSynchAddress" / construct.If(this.bLength == 9, construct.Optional(construct.Int8ul))) * "Synch Endpoint Address",
)


Expand Down Expand Up @@ -198,7 +199,6 @@ class DeviceCapabilityTypes(IntEnum):
)



class DescriptorParserCases(unittest.TestCase):

STRING_DESCRIPTOR = bytes([
Expand All @@ -225,7 +225,6 @@ class DescriptorParserCases(unittest.TestCase):
ord('s'), 0x00,
])


def test_string_descriptor_parse(self):

# Parse the relevant string...
Expand All @@ -236,23 +235,20 @@ def test_string_descriptor_parse(self):
self.assertEqual(parsed.bDescriptorType, 3)
self.assertEqual(parsed.bString, "Great Scott Gadgets")


def test_string_descriptor_build(self):
data = StringDescriptor.build({
'bString': "Great Scott Gadgets"
})

self.assertEqual(data, self.STRING_DESCRIPTOR)


def test_string_language_descriptor_build(self):
data = StringLanguageDescriptor.build({
'wLANGID': (LanguageIDs.ENGLISH_US,)
})

self.assertEqual(data, b"\x04\x03\x09\x04")


def test_device_descriptor(self):

device_descriptor = [
Expand Down Expand Up @@ -291,7 +287,6 @@ def test_device_descriptor(self):
self.assertEqual(parsed.iSerialNumber, 3)
self.assertEqual(parsed.bNumConfigurations, 1)


def test_bcd_constructor(self):

emitter = BCDFieldAdapter(construct.Int16ul)
Expand All @@ -300,5 +295,90 @@ def test_bcd_constructor(self):
self.assertEqual(result, b"\x40\x01")


def test_parse_endpoint_descriptor(self):
# Parse the relevant descriptor ...
parsed = EndpointDescriptor.parse([
0x07, # Length
0x05, # Type
0x81, # Endpoint address
0x02, # Attributes
0x40, 0x00, # Maximum packet size
0xFF, # Interval
])

# ... and check the descriptor's fields.
self.assertEqual(parsed.bLength, 7)
self.assertEqual(parsed.bDescriptorType, StandardDescriptorNumbers.ENDPOINT)
self.assertEqual(parsed.bEndpointAddress, 0x81)
self.assertEqual(parsed.bmAttributes, 2)
self.assertEqual(parsed.wMaxPacketSize, 64)
self.assertEqual(parsed.bInterval, 255)

def test_build_endpoint_descriptor(self):
# Build the relevant descriptor
data = EndpointDescriptor.build({
'bEndpointAddress': 0x81,
'bmAttributes': 2,
'wMaxPacketSize': 64,
'bInterval': 255,
})

# ... and check the binary output
self.assertEqual(data, bytes([
0x09, # Length
0x05, # Type
0x81, # Endpoint address
0x02, # Attributes
0x40, 0x00, # Maximum packet size
0xFF, # Interval
]))

def test_parse_endpoint_descriptor_audio(self):
# Parse the relevant descriptor ...
parsed = EndpointDescriptor.parse([
0x09, # Length
0x05, # Type
0x81, # Endpoint address
0x02, # Attributes
0x40, 0x00, # Maximum packet size
0xFF, # Interval
0x20, # Refresh rate
0x05, # Synch endpoint address
])

# ... and check the descriptor's fields.
self.assertEqual(parsed.bLength, 9)
self.assertEqual(parsed.bDescriptorType, StandardDescriptorNumbers.ENDPOINT)
self.assertEqual(parsed.bEndpointAddress, 0x81)
self.assertEqual(parsed.bmAttributes, 2)
self.assertEqual(parsed.wMaxPacketSize, 64)
self.assertEqual(parsed.bInterval, 255)
self.assertEqual(parsed.bRefresh, 32)
self.assertEqual(parsed.bSynchAddress, 0x05)

def test_build_endpoint_descriptor_audio(self):
# Build the relevant descriptor
data = EndpointDescriptor.build({
'bEndpointAddress': 0x81,
'bmAttributes': 2,
'wMaxPacketSize': 64,
'bInterval': 255,
'bRefresh': 32,
'bSynchAddress': 0x05,
})

# ... and check the binary output
self.assertEqual(data, bytes([
0x09, # Length
0x05, # Type
0x81, # Endpoint address
0x02, # Attributes
0x40, 0x00, # Maximum packet size
0xFF, # Interval
0x20, # Refresh rate
0x05, # Synch endpoint address
]))


if __name__ == "__main__":
unittest.main()