Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 72 additions & 4 deletions applets/axi_writer_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
class Top(Elaboratable):
def __init__(self):
self.reset = ControlSignal()
self.to_write = ControlSignal(reset=32 * 1024 * 1024)
self.to_write = ControlSignal(32)
self.needed_cycles = StatusSignal(32)
self.packet_size = ControlSignal(32, reset=1 * 1024 * 1024)
self.data_counter = StatusSignal(32)
self.data_valid = ControlSignal()
self.packet_counter = StatusSignal(32)
self.data_ready = StatusSignal()

def elaborate(self, platform: ZynqSocPlatform):
Expand All @@ -19,16 +21,82 @@ def elaborate(self, platform: ZynqSocPlatform):

stream = PacketizedStream(64)
m.d.comb += self.data_ready.eq(stream.ready)
m.d.comb += stream.valid.eq(self.data_valid)

done = Signal(reset = 0)

axi_writer = m.submodules.axi_writer = DramPacketRingbufferStreamWriter(stream, max_packet_size=0x1200000, n_buffers=4)
self.axi_writer = axi_writer


with m.If(~done):
m.d.sync += self.needed_cycles.eq(self.needed_cycles + 1)
m.d.comb += stream.valid.eq(1)

with m.If(((self.packet_counter + 1) == self.packet_size) | ((self.data_counter + 1) == self.to_write)):
m.d.comb += stream.last.eq(1)


with m.If(axi_writer.input.ready & axi_writer.input.valid):
m.d.sync += self.data_counter.eq(self.data_counter + 1)
with m.If((self.data_counter + 1) < self.to_write):
m.d.sync += self.data_counter.eq(self.data_counter + 1)
with m.Else():
m.d.sync += self.data_counter.eq(0)
m.d.sync += done.eq(1)

with m.If((self.packet_counter + 1) == self.packet_size):
m.d.sync += self.packet_counter.eq(0)
with m.Else():
m.d.sync += self.packet_counter.eq(self.packet_counter + 1)

m.d.comb += stream.payload.eq(Cat(self.data_counter, self.data_counter + 1000))

return m

@driver_method
def run_and_check(self, to_write = 4 * 1024 * 1024, packet_size = 1 * 1024 * 1024):
self.reset = 1
self.to_write = to_write
self.packet_size = packet_size
self.reset = 0

import time
time.sleep(0.5)

print(f"efficiency: {self.to_write / self.needed_cycles}")

written_buffers = (to_write + packet_size - 1) // packet_size
assert self.axi_writer.buffers_written == written_buffers

base_address = self.axi_writer.base_address
max_buffer = max(self.axi_writer.buffer_base_list_cpu)
map_len = max_buffer + packet_size * 8 - base_address

import mmap, os, sys
mem = mmap.mmap(
os.open('/dev/mem', os.O_RDWR | os.O_SYNC),
map_len, mmap.MAP_SHARED, mmap.PROT_READ | mmap.PROT_WRITE,
offset = base_address
)

last = (written_buffers - 4) * packet_size - 1
offs = written_buffers % self.axi_writer.n_buffers
bufs = self.axi_writer.buffer_base_list_cpu[offs:] + self.axi_writer.buffer_base_list_cpu[:offs]
for buf_addr in bufs:
for w in range(packet_size):
if w % 1000 == 0:
print(".", end="")
sys.stdout.flush()
addr = buf_addr - base_address + 8 * w
val = int.from_bytes(mem[addr:addr+8], 'little')
lower = val & ((1 << 32) - 1)
upper = val >> 32
# print(lower, upper, last)
assert lower == last + 1, f"[{buf_addr:08x}, {w:08x}]: {lower} != {last} + 1"
assert upper == lower + 1000, f"[{buf_addr:0x8}, {w:08x}]: {upper} != {lower} + 1000"
last = lower
print()



if __name__ == "__main__":
cli(Top, runs_on=(MicroR2Platform, BetaPlatform, ZyboPlatform), possible_socs=(ZynqSocPlatform,))
73 changes: 69 additions & 4 deletions naps/cores/dram_packet_ringbuffer/stream_if.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
from naps.cores import AxiReader, AxiWriter, if_none_get_zynq_hp_port, StreamInfo, LastWrapper, StreamTee
from naps import PacketizedStream, BasicStream, stream_transformer, StatusSignal

__all__ = ["DramPacketRingbufferStreamWriter", "DramPacketRingbufferStreamReader"]
__all__ = ["DramPacketRingbufferStreamWriter", "DramPacketRingbufferStreamWriterV2", "DramPacketRingbufferStreamReader"]


class DramPacketRingbufferStreamWriter(Elaboratable):
class DramPacketRingbufferStreamWriterV2(Elaboratable):
def __init__(
self,
input: PacketizedStream,
Expand All @@ -18,7 +17,8 @@ def __init__(
self.n_buffers = n_buffers
self.axi = axi

self.buffer_base_list = Array([base_address + max_packet_size * i for i in range(n_buffers)])
self.buffer_base_list_cpu = [base_address + max_packet_size * i for i in range(n_buffers)]
self.buffer_base_list = Array(self.buffer_base_list_cpu)
self.buffer_level_list = Array([Signal(range(max_packet_size), name=f'buffer{i}_level') for i in range(n_buffers)])
self.current_write_buffer = Signal(range(n_buffers))

Expand Down Expand Up @@ -69,6 +69,71 @@ def elaborate(self, platform):
return m



class DramPacketRingbufferStreamWriter(Elaboratable):
def __init__(
self,
input: PacketizedStream,
max_packet_size, n_buffers, base_address=0x0f80_0000,
axi=None,
):
self.max_packet_size = max_packet_size
self.base_address = base_address
self.n_buffers = n_buffers
self.axi = axi

self.buffer_base_list_cpu = [base_address + max_packet_size * i for i in range(n_buffers)]
self.buffer_base_list = Array(self.buffer_base_list_cpu)
self.buffer_level_list = Array([Signal(range(max_packet_size), name=f'buffer{i}_level') for i in range(n_buffers)])
self.current_write_buffer = Signal(range(n_buffers))

assert hasattr(input, "last")
self.input = input

self.overflowed_buffers = StatusSignal(32)
self.buffers_written = StatusSignal(32)

def elaborate(self, platform):
m = Module()

axi = if_none_get_zynq_hp_port(self.axi, m, platform)
assert len(self.input.payload) <= axi.data_bits

tee = m.submodules.tee = StreamTee(self.input)

data_stream = BasicStream(self.input.payload.shape())
m.d.comb += data_stream.connect_upstream(tee.get_output(), allow_partial=True)

transformer_input = tee.get_output()
address_stream = BasicStream(axi.write_address.payload.shape())
address_offset = Signal.like(axi.write_address.payload)
is_in_overflow = Signal()
stream_transformer(transformer_input, address_stream, m, latency=0, handle_out_of_band=False)

with m.If(transformer_input.ready & transformer_input.valid):
m.d.sync += self.buffer_level_list[self.current_write_buffer].eq(address_offset + axi.data_bytes)
with m.If(transformer_input.last):
m.d.sync += is_in_overflow.eq(0)
next_buffer = (self.current_write_buffer + 1) % self.n_buffers
m.d.sync += address_offset.eq(0)
m.d.sync += self.current_write_buffer.eq(next_buffer)
m.d.sync += self.buffers_written.eq(self.buffers_written + 1)
with m.Else():
with m.If((address_offset + axi.data_bytes < self.max_packet_size)):
m.d.sync += address_offset.eq(address_offset + axi.data_bytes)
with m.Else():
with m.If(~is_in_overflow):
m.d.sync += is_in_overflow.eq(1)
m.d.sync += self.overflowed_buffers.eq(self.overflowed_buffers + 1)
m.d.comb += address_stream.payload.eq(address_offset + self.buffer_base_list[self.current_write_buffer])

m.submodules.writer = AxiWriter(address_stream, data_stream, axi)

m.submodules.input_stream_info = StreamInfo(self.input)

return m


class DramPacketRingbufferStreamReader(Elaboratable):
def __init__(self, writer: DramPacketRingbufferStreamWriter, data_width=64, length_fifo_depth=1, axi=None):
self.writer = writer
Expand Down