diff --git a/src/machine/machine_atsamd21_usb.go b/src/machine/machine_atsamd21_usb.go index 7b9d2e14f8..e587968d8e 100644 --- a/src/machine/machine_atsamd21_usb.go +++ b/src/machine/machine_atsamd21_usb.go @@ -336,7 +336,7 @@ func handleUSBSetAddress(setup usb.Setup) bool { } // SendUSBInPacket sends a packet for USB (interrupt in / bulk in). -func SendUSBInPacket(ep uint32, data []byte) bool { +func (dev *USBDevice) SendUSBInPacket(ep uint32, data []byte) bool { sendUSBPacket(ep, data, 0) // clear transfer complete flag @@ -374,7 +374,7 @@ func sendUSBPacket(ep uint32, data []byte, maxsize uint16) { usbEndpointDescriptors[ep].DeviceDescBank[1].PCKSIZE.SetBits((uint32(l) & usb_DEVICE_PCKSIZE_BYTE_COUNT_Mask) << usb_DEVICE_PCKSIZE_BYTE_COUNT_Pos) } -func ReceiveUSBControlPacket() ([cdcLineInfoSize]byte, error) { +func (dev *USBDevice) ReceiveUSBControlPacket() ([cdcLineInfoSize]byte, error) { var b [cdcLineInfoSize]byte // Wait until OUT transfer is ready. @@ -417,7 +417,7 @@ func handleEndpointRx(ep uint32) []byte { } // AckUsbOutTransfer is called to acknowledge the completion of a USB OUT transfer. -func AckUsbOutTransfer(ep uint32) { +func (dev *USBDevice) AckUsbOutTransfer(ep uint32) { // set byte count to zero usbEndpointDescriptors[ep].DeviceDescBank[0].PCKSIZE.ClearBits(usb_DEVICE_PCKSIZE_BYTE_COUNT_Mask << usb_DEVICE_PCKSIZE_BYTE_COUNT_Pos) @@ -426,10 +426,9 @@ func AckUsbOutTransfer(ep uint32) { // set ready for next data setEPSTATUSCLR(ep, sam.USB_DEVICE_EPSTATUSCLR_BK0RDY) - } -func SendZlp() { +func (dev *USBDevice) SendZlp() { usbEndpointDescriptors[0].DeviceDescBank[1].PCKSIZE.ClearBits(usb_DEVICE_PCKSIZE_BYTE_COUNT_Mask << usb_DEVICE_PCKSIZE_BYTE_COUNT_Pos) } @@ -662,3 +661,23 @@ func setEPINTENSET(ep uint32, val uint8) { return } } + +// Set ENDPOINT_HALT/stall status on a USB IN endpoint. +func (dev *USBDevice) SetStallEPIn(ep uint32) { + setEPSTATUSSET(ep, sam.USB_DEVICE_EPSTATUSSET_STALLRQ1) +} + +// Set ENDPOINT_HALT/stall status on a USB OUT endpoint. +func (dev *USBDevice) SetStallEPOut(ep uint32) { + setEPSTATUSSET(ep, sam.USB_DEVICE_EPSTATUSSET_STALLRQ0) +} + +// Clear the ENDPOINT_HALT/stall on a USB IN endpoint. +func (dev *USBDevice) ClearStallEPIn(ep uint32) { + setEPSTATUSCLR(ep, sam.USB_DEVICE_EPSTATUSCLR_STALLRQ1) +} + +// Clear the ENDPOINT_HALT/stall on a USB OUT endpoint. +func (dev *USBDevice) ClearStallEPOut(ep uint32) { + setEPSTATUSCLR(ep, sam.USB_DEVICE_EPSTATUSCLR_STALLRQ0) +} diff --git a/src/machine/machine_atsamd51_usb.go b/src/machine/machine_atsamd51_usb.go index a95089f75e..c8299ecc48 100644 --- a/src/machine/machine_atsamd51_usb.go +++ b/src/machine/machine_atsamd51_usb.go @@ -339,7 +339,7 @@ func handleUSBSetAddress(setup usb.Setup) bool { } // SendUSBInPacket sends a packet for USB (interrupt in / bulk in). -func SendUSBInPacket(ep uint32, data []byte) bool { +func (dev *USBDevice) SendUSBInPacket(ep uint32, data []byte) bool { sendUSBPacket(ep, data, 0) // clear transfer complete flag @@ -377,7 +377,7 @@ func sendUSBPacket(ep uint32, data []byte, maxsize uint16) { usbEndpointDescriptors[ep].DeviceDescBank[1].PCKSIZE.SetBits((uint32(l) & usb_DEVICE_PCKSIZE_BYTE_COUNT_Mask) << usb_DEVICE_PCKSIZE_BYTE_COUNT_Pos) } -func ReceiveUSBControlPacket() ([cdcLineInfoSize]byte, error) { +func (dev *USBDevice) ReceiveUSBControlPacket() ([cdcLineInfoSize]byte, error) { var b [cdcLineInfoSize]byte // Wait until OUT transfer is ready. @@ -420,7 +420,7 @@ func handleEndpointRx(ep uint32) []byte { } // AckUsbOutTransfer is called to acknowledge the completion of a USB OUT transfer. -func AckUsbOutTransfer(ep uint32) { +func (dev *USBDevice) AckUsbOutTransfer(ep uint32) { // set byte count to zero usbEndpointDescriptors[ep].DeviceDescBank[0].PCKSIZE.ClearBits(usb_DEVICE_PCKSIZE_BYTE_COUNT_Mask << usb_DEVICE_PCKSIZE_BYTE_COUNT_Pos) @@ -431,7 +431,7 @@ func AckUsbOutTransfer(ep uint32) { setEPSTATUSCLR(ep, sam.USB_DEVICE_ENDPOINT_EPSTATUSCLR_BK0RDY) } -func SendZlp() { +func (dev *USBDevice) SendZlp() { usbEndpointDescriptors[0].DeviceDescBank[1].PCKSIZE.ClearBits(usb_DEVICE_PCKSIZE_BYTE_COUNT_Mask << usb_DEVICE_PCKSIZE_BYTE_COUNT_Pos) } @@ -493,3 +493,23 @@ func setEPINTENCLR(ep uint32, val uint8) { func setEPINTENSET(ep uint32, val uint8) { sam.USB_DEVICE.DEVICE_ENDPOINT[ep].EPINTENSET.Set(val) } + +// Set ENDPOINT_HALT/stall status on a USB IN endpoint. +func (dev *USBDevice) SetStallEPIn(ep uint32) { + setEPSTATUSSET(ep, sam.USB_DEVICE_ENDPOINT_EPSTATUSSET_STALLRQ1) +} + +// Set ENDPOINT_HALT/stall status on a USB OUT endpoint. +func (dev *USBDevice) SetStallEPOut(ep uint32) { + setEPSTATUSSET(ep, sam.USB_DEVICE_ENDPOINT_EPSTATUSSET_STALLRQ0) +} + +// Clear the ENDPOINT_HALT/stall on a USB IN endpoint. +func (dev *USBDevice) ClearStallEPIn(ep uint32) { + setEPSTATUSCLR(ep, sam.USB_DEVICE_ENDPOINT_EPSTATUSCLR_STALLRQ1) +} + +// Clear the ENDPOINT_HALT/stall on a USB OUT endpoint. +func (dev *USBDevice) ClearStallEPOut(ep uint32) { + setEPSTATUSCLR(ep, sam.USB_DEVICE_ENDPOINT_EPSTATUSCLR_STALLRQ0) +} diff --git a/src/machine/machine_nrf52840_usb.go b/src/machine/machine_nrf52840_usb.go index 1fa46945fa..b5d6bff8eb 100644 --- a/src/machine/machine_nrf52840_usb.go +++ b/src/machine/machine_nrf52840_usb.go @@ -255,12 +255,9 @@ func initEndpoint(ep, config uint32) { } // SendUSBInPacket sends a packet for USBHID (interrupt in / bulk in). -func SendUSBInPacket(ep uint32, data []byte) bool { +func (dev *USBDevice) SendUSBInPacket(ep uint32, data []byte) bool { sendUSBPacket(ep, data, 0) - // clear transfer complete flag - nrf.USBD.INTENCLR.Set(nrf.USBD_INTENCLR_ENDEPOUT0 << 4) - return true } @@ -304,15 +301,64 @@ func handleEndpointRx(ep uint32) []byte { } // AckUsbOutTransfer is called to acknowledge the completion of a USB OUT transfer. -func AckUsbOutTransfer(ep uint32) { +func (dev *USBDevice) AckUsbOutTransfer(ep uint32) { // set ready for next data nrf.USBD.SIZE.EPOUT[ep].Set(0) } -func SendZlp() { +func (dev *USBDevice) SendZlp() { nrf.USBD.TASKS_EP0STATUS.Set(1) } +// Set the USB endpoint Packet ID to DATA0 or DATA1. +// In endpoints must have bit 7 (0x80) set. +func setEPDataPID(ep uint32, dataOne bool) { + // nrf52840 DTOGGLE requires a "Select" write first (Value=Nop=0), + // then a "Set" write (Value=Data0/Data1). + + // Select Endpoint (Value=Nop=0) + nrf.USBD.DTOGGLE.Set(ep) + + // Now write the value + val := ep + if dataOne { + val |= nrf.USBD_DTOGGLE_VALUE_Data1 << nrf.USBD_DTOGGLE_VALUE_Pos + } else { + val |= nrf.USBD_DTOGGLE_VALUE_Data0 << nrf.USBD_DTOGGLE_VALUE_Pos + } + nrf.USBD.DTOGGLE.Set(val) +} + +// Set ENDPOINT_HALT/stall status on a USB IN endpoint. +func (dev *USBDevice) SetStallEPIn(ep uint32) { + // Bit 8 is STALL, Bit 7 is IO (1 for IN), Bits 0-2 are EP number. + nrf.USBD.EPSTALL.Set((1 << 8) | (1 << 7) | (ep & 0x7)) +} + +// Set ENDPOINT_HALT/stall status on a USB OUT endpoint. +func (dev *USBDevice) SetStallEPOut(ep uint32) { + // Bit 8 is STALL, Bit 7 is IO (0 for OUT), Bits 0-2 are EP number. + nrf.USBD.EPSTALL.Set((1 << 8) | (0 << 7) | (ep & 0x7)) +} + +// Clear the ENDPOINT_HALT/stall on a USB IN endpoint. +func (dev *USBDevice) ClearStallEPIn(ep uint32) { + // Reset Data Toggle to DATA0 when unstalling. + setEPDataPID(ep|usb.EndpointIn, false) + + // Bit 8 is STALL (0 for UnStall), Bit 7 is IO (1 for IN), Bits 0-2 are EP number. + nrf.USBD.EPSTALL.Set((0 << 8) | (1 << 7) | (ep & 0x7)) +} + +// Clear the ENDPOINT_HALT/stall on a USB OUT endpoint. +func (dev *USBDevice) ClearStallEPOut(ep uint32) { + // Reset Data Toggle to DATA0 when unstalling. + setEPDataPID(ep, false) + + // Bit 8 is STALL (0 for UnStall), Bit 7 is IO (0 for OUT), Bits 0-2 are EP number. + nrf.USBD.EPSTALL.Set((0 << 8) | (0 << 7) | (ep & 0x7)) +} + func sendViaEPIn(ep uint32, ptr *byte, count int) { nrf.USBD.EPIN[ep].PTR.Set( uint32(uintptr(unsafe.Pointer(ptr))), @@ -336,7 +382,7 @@ func handleUSBSetAddress(setup usb.Setup) bool { return true } -func ReceiveUSBControlPacket() ([cdcLineInfoSize]byte, error) { +func (dev *USBDevice) ReceiveUSBControlPacket() ([cdcLineInfoSize]byte, error) { var b [cdcLineInfoSize]byte nrf.USBD.TASKS_EP0RCVOUT.Set(1) diff --git a/src/machine/machine_rp2_usb.go b/src/machine/machine_rp2_usb.go index 297cc9d9cf..0f43e61d0e 100644 --- a/src/machine/machine_rp2_usb.go +++ b/src/machine/machine_rp2_usb.go @@ -71,7 +71,7 @@ func initEndpoint(ep, config uint32) { } // SendUSBInPacket sends a packet for USB (interrupt in / bulk in). -func SendUSBInPacket(ep uint32, data []byte) bool { +func (dev *USBDevice) SendUSBInPacket(ep uint32, data []byte) bool { sendUSBPacket(ep, data, 0) return true } @@ -100,7 +100,7 @@ func sendUSBPacket(ep uint32, data []byte, maxsize uint16) { sendViaEPIn(ep, data, count) } -func ReceiveUSBControlPacket() ([cdcLineInfoSize]byte, error) { +func (dev *USBDevice) ReceiveUSBControlPacket() ([cdcLineInfoSize]byte, error) { var b [cdcLineInfoSize]byte ep := 0 @@ -129,7 +129,7 @@ func handleEndpointRx(ep uint32) []byte { } // AckUsbOutTransfer is called to acknowledge the completion of a USB OUT transfer. -func AckUsbOutTransfer(ep uint32) { +func (dev *USBDevice) AckUsbOutTransfer(ep uint32) { ep = ep & 0x7F setEPDataPID(ep, !epXdata0[ep]) } @@ -144,7 +144,7 @@ func setEPDataPID(ep uint32, dataOne bool) { _usbDPSRAM.EPxBufferControl[ep].Out.SetBits(usbBuf0CtrlAvail) } -func SendZlp() { +func (dev *USBDevice) SendZlp() { sendUSBPacket(0, []byte{}, 0) } diff --git a/src/machine/usb.go b/src/machine/usb.go index 9682521036..3c20587c4f 100644 --- a/src/machine/usb.go +++ b/src/machine/usb.go @@ -19,9 +19,12 @@ var ( USBCDC Serialer ) +func init() { + usb.DefaultController = USBDev +} + func initUSB() { - enableUSBCDC() - USBDev.Configure(UARTConfig{}) + USBDev.Enable() } // Using go:linkname here because there's a circular dependency between the @@ -30,6 +33,10 @@ func initUSB() { //go:linkname enableUSBCDC machine/usb/cdc.EnableUSBCDC func enableUSBCDC() +func ReceiveUSBControlPacket() ([7]byte, error) { + return USBDev.ReceiveUSBControlPacket() +} + type Serialer interface { WriteByte(c byte) error Write(data []byte) (n int, err error) @@ -285,6 +292,19 @@ func handleStandardSetup(setup usb.Setup) bool { } } +func (d *USBDevice) Enable() { + if d.initcomplete { + return + } + enableUSBCDC() + d.Configure(UARTConfig{}) + d.initcomplete = true +} + +func (d *USBDevice) IsInitEndpointComplete() bool { + return d.InitEndpointComplete +} + func EnableCDC(txHandler func(), rxHandler func([]byte), setupHandler func(usb.Setup) bool) { if len(usbDescriptor.Device) == 0 { usbDescriptor = descriptor.CDC @@ -319,6 +339,10 @@ func EnableCDC(txHandler func(), rxHandler func([]byte), setupHandler func(usb.S } func ConfigureUSBEndpoint(desc descriptor.Descriptor, epSettings []usb.EndpointConfig, setup []usb.SetupConfig) { + USBDev.ConfigureUSBEndpoint(desc, epSettings, setup) +} + +func (d *USBDevice) ConfigureUSBEndpoint(desc descriptor.Descriptor, epSettings []usb.EndpointConfig, setup []usb.SetupConfig) { usbDescriptor = desc for _, ep := range epSettings { @@ -347,3 +371,16 @@ func ConfigureUSBEndpoint(desc descriptor.Descriptor, epSettings []usb.EndpointC usbSetupHandler[s.Index] = s.Handler } } + +// Old usb functions kept for compatibility +func AckUsbOutTransfer(ep uint32) { + USBDev.AckUsbOutTransfer(ep) +} + +func SendUSBInPacket(ep uint32, data []byte) bool { + return USBDev.SendUSBInPacket(ep, data) +} + +func SendZlp() { + USBDev.SendZlp() +} diff --git a/src/machine/usb/cdc/cdc.go b/src/machine/usb/cdc/cdc.go index f180535df1..a82323a9fd 100644 --- a/src/machine/usb/cdc/cdc.go +++ b/src/machine/usb/cdc/cdc.go @@ -1,5 +1,7 @@ package cdc +import "machine/usb" + const ( cdcEndpointACM = 1 cdcEndpointOut = 2 @@ -12,6 +14,7 @@ func New() *USBCDC { USB = &USBCDC{ rxBuffer: NewRxRingBuffer(), txBuffer: NewTxRingBuffer(), + dev: usb.DefaultController, } } return USB diff --git a/src/machine/usb/cdc/usbcdc.go b/src/machine/usb/cdc/usbcdc.go index 5b5ffbf7c4..512535efad 100644 --- a/src/machine/usb/cdc/usbcdc.go +++ b/src/machine/usb/cdc/usbcdc.go @@ -70,6 +70,7 @@ type USBCDC struct { rxBuffer *rxRingBuffer txBuffer *txRingBuffer waitTxc bool + dev usb.Controller } var ( @@ -88,7 +89,7 @@ func (usbcdc *USBCDC) Configure(config machine.UARTConfig) error { func (usbcdc *USBCDC) Flush() { mask := interrupt.Disable() if b, ok := usbcdc.txBuffer.Get(); ok { - machine.SendUSBInPacket(cdcEndpointIn, b) + usbcdc.dev.SendUSBInPacket(cdcEndpointIn, b) } else { usbcdc.waitTxc = false } @@ -123,15 +124,15 @@ func (usbcdc *USBCDC) RTS() bool { return (usbLineInfo.lineState & usb_CDC_LINESTATE_RTS) > 0 } -func cdcCallbackRx(b []byte) { +func (usbcdc *USBCDC) Rx(b []byte) { for i := range b { - USB.Receive(b[i]) + usbcdc.Receive(b[i]) } } var cdcSetupBuff [cdcLineInfoSize]byte -func cdcSetup(setup usb.Setup) bool { +func (usbcdc *USBCDC) Setup(setup usb.Setup) bool { if setup.BmRequestType == usb_REQUEST_DEVICETOHOST_CLASS_INTERFACE { if setup.BRequest == usb_CDC_GET_LINE_CODING { cdcSetupBuff[0] = byte(usbLineInfo.dwDTERate) @@ -142,14 +143,14 @@ func cdcSetup(setup usb.Setup) bool { cdcSetupBuff[5] = byte(usbLineInfo.bParityType) cdcSetupBuff[6] = byte(usbLineInfo.bDataBits) - machine.SendUSBInPacket(0, cdcSetupBuff[:]) + usbcdc.dev.SendUSBInPacket(0, cdcSetupBuff[:]) return true } } if setup.BmRequestType == usb_REQUEST_HOSTTODEVICE_CLASS_INTERFACE { if setup.BRequest == usb_CDC_SET_LINE_CODING { - b, err := machine.ReceiveUSBControlPacket() + b, err := usbcdc.dev.ReceiveUSBControlPacket() if err != nil { return false } @@ -171,14 +172,14 @@ func cdcSetup(setup usb.Setup) bool { } else { // TODO: cancel any reset } - machine.SendZlp() + usbcdc.dev.SendZlp() } if setup.BRequest == usb_CDC_SEND_BREAK { // TODO: something with this value? // breakValue = ((uint16_t)setup.wValueH << 8) | setup.wValueL; // return false; - machine.SendZlp() + usbcdc.dev.SendZlp() } return true } @@ -186,6 +187,7 @@ func cdcSetup(setup usb.Setup) bool { } func EnableUSBCDC() { - machine.USBCDC = New() - machine.EnableCDC(USB.Flush, cdcCallbackRx, cdcSetup) + c := New() + machine.USBCDC = c + machine.EnableCDC(c.Flush, c.Rx, c.Setup) } diff --git a/src/machine/usb/device.go b/src/machine/usb/device.go new file mode 100644 index 0000000000..34dcf43976 --- /dev/null +++ b/src/machine/usb/device.go @@ -0,0 +1,22 @@ +package usb + +import ( + "machine/usb/descriptor" +) + +// Controller abstracts the USB interactions to allow for testing without hardware. +type Controller interface { + Enable() + ConfigureUSBEndpoint(desc descriptor.Descriptor, epSettings []EndpointConfig, setup []SetupConfig) + SendUSBInPacket(ep uint32, data []byte) bool + AckUsbOutTransfer(ep uint32) + SendZlp() + IsInitEndpointComplete() bool + SetStallEPIn(ep uint32) + SetStallEPOut(ep uint32) + ClearStallEPIn(ep uint32) + ClearStallEPOut(ep uint32) + ReceiveUSBControlPacket() ([7]byte, error) +} + +var DefaultController Controller diff --git a/src/machine/usb/msc/disk.go b/src/machine/usb/msc/disk.go index 6624d38c01..1b19117cf2 100644 --- a/src/machine/usb/msc/disk.go +++ b/src/machine/usb/msc/disk.go @@ -4,7 +4,6 @@ import ( "encoding/binary" "errors" "fmt" - "machine" "time" ) @@ -13,7 +12,7 @@ var ( ) // RegisterBlockDevice registers a BlockDevice provider with the MSC driver -func (m *msc) RegisterBlockDevice(dev machine.BlockDevice) { +func (m *msc) RegisterBlockDevice(dev BlockDevice) { m.dev = dev if cap(m.blockCache) != int(dev.WriteBlockSize()) { @@ -56,11 +55,11 @@ func (m *msc) RegisterBlockDevice(dev machine.BlockDevice) { } } -var _ machine.BlockDevice = (*RecorderDisk)(nil) +var _ BlockDevice = (*RecorderDisk)(nil) // RecorderDisk is a block device that records actions taken on it type RecorderDisk struct { - dev machine.BlockDevice + dev BlockDevice log []RecorderRecord last time.Time time time.Time @@ -84,7 +83,7 @@ const ( ) // NewRecorderDisk creates a new RecorderDisk instance -func NewRecorderDisk(dev machine.BlockDevice, count int) *RecorderDisk { +func NewRecorderDisk(dev BlockDevice, count int) *RecorderDisk { d := &RecorderDisk{ dev: dev, log: make([]RecorderRecord, 0, count), diff --git a/src/machine/usb/msc/interfaces.go b/src/machine/usb/msc/interfaces.go new file mode 100644 index 0000000000..2c382e9094 --- /dev/null +++ b/src/machine/usb/msc/interfaces.go @@ -0,0 +1,34 @@ +package msc + +import ( + "io" +) + +// BlockDevice is the raw device that is meant to store flash data. +// It mimics the interface defined in machine/flash.go to allow for decoupling. +type BlockDevice interface { + // ReadAt reads the given number of bytes from the block device. + io.ReaderAt + + // WriteAt writes the given number of bytes to the block device. + io.WriterAt + + // Size returns the number of bytes in this block device. + Size() int64 + + // WriteBlockSize returns the block size in which data can be written to + // memory. It can be used by a client to optimize writes, non-aligned writes + // should always work correctly. + WriteBlockSize() int64 + + // EraseBlockSize returns the smallest erasable area on this particular chip + // in bytes. This is used for the block size in EraseBlocks. + // It must be a power of two, and may be as small as 1. A typical size is 4096. + EraseBlockSize() int64 + + // EraseBlocks erases the given number of blocks. An implementation may + // transparently coalesce ranges of blocks into larger bundles if the chip + // supports this. The start and len parameters are in block numbers, use + // EraseBlockSize to map addresses to blocks. + EraseBlocks(start, len int64) error +} diff --git a/src/machine/usb/msc/mock_usb_device.go b/src/machine/usb/msc/mock_usb_device.go new file mode 100644 index 0000000000..544435356d --- /dev/null +++ b/src/machine/usb/msc/mock_usb_device.go @@ -0,0 +1,63 @@ +package msc + +import ( + "machine/usb" + "machine/usb/descriptor" +) + +// MockUSBDevice implements usb.Controller for testing. +type MockUSBDevice struct { + InPackets [][]byte + OutAcked bool + StallIn bool + StallOut bool + InitEndpointDone bool + ZlpSent bool +} + +func (m *MockUSBDevice) ConfigureUSBEndpoint(desc descriptor.Descriptor, epSettings []usb.EndpointConfig, setup []usb.SetupConfig) { +} + +func (m *MockUSBDevice) SendUSBInPacket(ep uint32, data []byte) bool { + // Copy data to avoid modification issues if buffer is reused + packet := make([]byte, len(data)) + copy(packet, data) + m.InPackets = append(m.InPackets, packet) + return true +} + +func (m *MockUSBDevice) AckUsbOutTransfer(ep uint32) { + m.OutAcked = true +} + +func (m *MockUSBDevice) SendZlp() { + m.ZlpSent = true + m.SendUSBInPacket(0, []byte{}) +} + +func (m *MockUSBDevice) IsInitEndpointComplete() bool { + return m.InitEndpointDone +} + +func (m *MockUSBDevice) SetStallEPIn(ep uint32) { + m.StallIn = true +} + +func (m *MockUSBDevice) SetStallEPOut(ep uint32) { + m.StallOut = true +} + +func (m *MockUSBDevice) ClearStallEPIn(ep uint32) { + m.StallIn = false +} + +func (m *MockUSBDevice) ClearStallEPOut(ep uint32) { + m.StallOut = false +} + +func (m *MockUSBDevice) Enable() { +} + +func (m *MockUSBDevice) ReceiveUSBControlPacket() ([7]byte, error) { + return [7]byte{}, nil +} diff --git a/src/machine/usb/msc/msc.go b/src/machine/usb/msc/msc.go index 420a06ed98..ae511663f9 100644 --- a/src/machine/usb/msc/msc.go +++ b/src/machine/usb/msc/msc.go @@ -1,7 +1,6 @@ package msc import ( - "machine" "machine/usb" "machine/usb/descriptor" "machine/usb/msc/csw" @@ -43,7 +42,8 @@ type msc struct { state mscState maxLUN uint8 // Maximum Logical Unit Number (n-1 for n LUNs) - dev machine.BlockDevice + dev BlockDevice + usb usb.Controller blockCount uint32 // Number of blocks in the device blockOffset uint32 // Byte offset of the first block in the device for aligned writes blockSizeUSB uint32 // Write block size as presented to the host over USB @@ -60,14 +60,14 @@ type msc struct { } // Port returns the USB Mass Storage port -func Port(dev machine.BlockDevice) *msc { +func Port(dev BlockDevice) *msc { if MSC == nil { - MSC = newMSC(dev) + MSC = newMSC(dev, usb.DefaultController) } return MSC } -func newMSC(dev machine.BlockDevice) *msc { +func newMSC(dev BlockDevice, usbCtrl usb.Controller) *msc { // Size our buffer to match the maximum packet size of the IN endpoint maxPacketSize := descriptor.EndpointMSCIN.GetMaxPacketSize() m := &msc{ @@ -78,7 +78,9 @@ func newMSC(dev machine.BlockDevice) *msc { cswBuf: make([]byte, csw.MsgLen), cbw: &CBW{Data: make([]byte, 31)}, maxPacketSize: uint32(maxPacketSize), + usb: usbCtrl, } + m.usb.Enable() m.RegisterBlockDevice(dev) // Set default inquiry data fields @@ -87,7 +89,7 @@ func newMSC(dev machine.BlockDevice) *msc { m.SetProductRev("1.0") // Initialize the USB Mass Storage Class (MSC) port - machine.ConfigureUSBEndpoint(descriptor.MSC, + m.usb.ConfigureUSBEndpoint(descriptor.MSC, []usb.EndpointConfig{ { Index: usb.MSC_ENDPOINT_IN, @@ -132,7 +134,7 @@ func (m *msc) processTasks() { // Acknowledge the received data from the host m.queuedBytes = 0 m.taskQueued = false - machine.AckUsbOutTransfer(usb.MSC_ENDPOINT_OUT) + m.usb.AckUsbOutTransfer(usb.MSC_ENDPOINT_OUT) } time.Sleep(100 * time.Microsecond) } @@ -151,9 +153,9 @@ func (m *msc) resetBuffer(length int) { } func (m *msc) sendUSBPacket(b []byte) { - if machine.USBDev.InitEndpointComplete { + if m.usb.IsInitEndpointComplete() { // Send the USB packet - machine.SendUSBInPacket(usb.MSC_ENDPOINT_IN, b) + m.usb.SendUSBInPacket(usb.MSC_ENDPOINT_IN, b) } } @@ -161,11 +163,12 @@ func (m *msc) sendCSW(status csw.Status) { // Generate CSW packet into m.cswBuf and send it residue := uint32(0) expected := m.cbw.transferLength() - if expected >= m.sentBytes { - residue = expected - m.sentBytes + if expected >= m.sentBytes+m.queuedBytes { + residue = expected - (m.sentBytes + m.queuedBytes) } m.cbw.CSW(status, residue, m.cswBuf) m.state = mscStateStatusSent + m.queuedBytes = csw.MsgLen m.sendUSBPacket(m.cswBuf) } diff --git a/src/machine/usb/msc/msc_test.go b/src/machine/usb/msc/msc_test.go new file mode 100644 index 0000000000..7348e0d60d --- /dev/null +++ b/src/machine/usb/msc/msc_test.go @@ -0,0 +1,217 @@ +package msc + +import ( + "encoding/binary" + "machine/usb" + "machine/usb/msc/csw" + "testing" +) + +// MockBlockDevice implements machine.BlockDevice for testing. +type MockBlockDevice struct { + Data []byte + BlockSize int64 + ReadCount int + WriteCount int + LastWriteAt int64 +} + +func (m *MockBlockDevice) ReadAt(p []byte, off int64) (n int, err error) { + m.ReadCount++ + if off >= int64(len(m.Data)) { + return 0, nil + } + n = copy(p, m.Data[off:]) + return n, nil +} + +func (m *MockBlockDevice) WriteAt(p []byte, off int64) (n int, err error) { + m.WriteCount++ + m.LastWriteAt = off + if off >= int64(len(m.Data)) { + // Expand data if needed + newData := make([]byte, off+int64(len(p))) + copy(newData, m.Data) + m.Data = newData + } + n = copy(m.Data[off:], p) + return n, nil +} + +func (m *MockBlockDevice) Size() int64 { + return int64(len(m.Data)) +} + +func (m *MockBlockDevice) WriteBlockSize() int64 { + return m.BlockSize +} + +func (m *MockBlockDevice) EraseBlockSize() int64 { + return m.BlockSize +} + +func (m *MockBlockDevice) EraseBlocks(start, len int64) error { + return nil +} + +// TestCBWParser verifies that the CBW is correctly parsed. +func TestCBWParser(t *testing.T) { + // Create a valid CBW + // Signature: USBC (0x43425355) + // Tag: 0x12345678 + // Transfer Length: 512 (0x200) + // Flags: 0x80 (IN) + // LUN: 0 + // Length: 10 + // CBD: SCSI Read(10) command (dummy) + cbwData := []byte{ + 0x55, 0x53, 0x42, 0x43, // Signature + 0x78, 0x56, 0x34, 0x12, // Tag + 0x00, 0x02, 0x00, 0x00, // Data Transfer Length (512) + 0x80, // Flags (Direction: IN) + 0x00, // LUN + 0x0A, // CBD Length + // SCSI Command (Read 10) - dummy + 0x28, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Padding + } + + // Setup mocks + mockDev := &MockBlockDevice{Data: make([]byte, 1024), BlockSize: 512} + mockUSB := &MockUSBDevice{InitEndpointDone: true} + + // Initialize MSC + m := newMSC(mockDev, mockUSB) + + // Manually feed the CBW to the run loop + // run(b, true) simulates receiving an OUT packet + ack := m.run(cbwData, true) + + if !ack { + t.Error("Expected ACK for valid CBW") + } + + // Check if CBW was parsed correctly + if m.cbw.Tag() != 0x12345678 { + t.Errorf("Expected Tag 0x12345678, got 0x%x", m.cbw.Tag()) + } + + if m.transferBytes != 512 { + t.Errorf("Expected transferBytes 512, got %d", m.transferBytes) + } + + if m.state != mscStateData { + t.Errorf("Expected state mscStateData, got %d", m.state) + } +} + +// TestResidueLogic simulates a Short Write scenario and verifies the Residue calculation. +func TestResidueLogic(t *testing.T) { + // CBW for WRITE (OUT), 512 bytes + cbwData := []byte{ + 0x55, 0x53, 0x42, 0x43, // Signature + 0x11, 0x22, 0x33, 0x44, // Tag + 0x00, 0x02, 0x00, 0x00, // Data Transfer Length (512) + 0x00, // Flags (Direction: OUT) + 0x00, // LUN + 0x0A, // CBD Length + // SCSI Write(10) command + 0x2A, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Padding + } + + mockDev := &MockBlockDevice{Data: make([]byte, 1024), BlockSize: 512} + mockUSB := &MockUSBDevice{InitEndpointDone: true} + m := newMSC(mockDev, mockUSB) + + // 1. Send CBW + m.run(cbwData, true) + if m.state != mscStateData { + t.Fatalf("Failed to transition to Data state. State: %d", m.state) + } + + // 2. Host sends 256 bytes (half block) + dataPacket := make([]byte, 256) + m.run(dataPacket, true) + + // queuedBytes should be 256. sentBytes (bytes written to block device) should be 0 because we buffer a full block. + if m.queuedBytes != 256 { + t.Errorf("Expected queuedBytes 256, got %d", m.queuedBytes) + } + if m.sentBytes != 0 { + t.Errorf("Expected sentBytes 0, got %d", m.sentBytes) + } + + // Force state to Status to verify Residue calculation + m.state = mscStateStatus + + // Call run to trigger CSW send + m.run([]byte{}, false) // IN endpoint event (dummy) + + // Check if CSW was sent + if len(mockUSB.InPackets) == 0 { + t.Fatal("No CSW sent") + } + + // The last packet should be the CSW + cswPacket := mockUSB.InPackets[len(mockUSB.InPackets)-1] + if len(cswPacket) != csw.MsgLen { + t.Errorf("CSW length mismatch. Expected %d, got %d", csw.MsgLen, len(cswPacket)) + } + + // Parse CSW + // Signature: 0-4 + // Tag: 4-8 + // Residue: 8-12 + // Status: 12 + + signature := binary.LittleEndian.Uint32(cswPacket[:4]) + if signature != csw.Signature { + t.Errorf("Invalid CSW Signature: %x", signature) + } + + tag := binary.LittleEndian.Uint32(cswPacket[4:8]) + if tag != 0x44332211 { // Little Endian of 11 22 33 44 + t.Errorf("Invalid CSW Tag: %x", tag) + } + + residue := binary.LittleEndian.Uint32(cswPacket[8:12]) + // Expected Residue = Expected Length (512) - Processed (256) = 256 + if residue != 256 { + t.Errorf("Incorrect Residue. Expected 256, got %d", residue) + } + + status := cswPacket[12] + if status != byte(csw.StatusPassed) { + t.Errorf("Incorrect Status. Expected %d (Passed), got %d", csw.StatusPassed, status) + } +} + +func TestSetupPacketHandler(t *testing.T) { + mockDev := &MockBlockDevice{Data: make([]byte, 1024), BlockSize: 512} + mockUSB := &MockUSBDevice{InitEndpointDone: true} + m := newMSC(mockDev, mockUSB) + + // Test Get Max LUN (Class Request 0xFE) + setup := usb.Setup{ + BmRequestType: 0xA1, // Device-to-Host, Class, Interface + BRequest: 0xFE, // GET MAX LUN + WValueL: 0, + WValueH: 0, + WIndex: mscInterface, + WLength: 1, + } + + handled := m.setupPacketHandler(setup) + if !handled { + t.Error("Expected GetMaxLUN to be handled") + } + + if len(mockUSB.InPackets) != 1 { + t.Fatalf("Expected 1 IN packet (Max LUN), got %d", len(mockUSB.InPackets)) + } + + if mockUSB.InPackets[0][0] != m.maxLUN { + t.Errorf("Expected MaxLUN %d, got %d", m.maxLUN, mockUSB.InPackets[0][0]) + } +} diff --git a/src/machine/usb/msc/ramdisk.go b/src/machine/usb/msc/ramdisk.go new file mode 100644 index 0000000000..bf44f4b18c --- /dev/null +++ b/src/machine/usb/msc/ramdisk.go @@ -0,0 +1,73 @@ +package msc + +import ( + "errors" +) + +// RamDisk implements machine.BlockDevice in memory. +type RamDisk struct { + Data []byte + BlockSize int64 +} + +// NewRamDisk creates a new RamDisk with the given size. +func NewRamDisk(size int64) *RamDisk { + return &RamDisk{ + Data: make([]byte, size), + BlockSize: 512, + } +} + +// ReadAt reads the given number of bytes from the block device. +func (r *RamDisk) ReadAt(p []byte, off int64) (n int, err error) { + if off >= int64(len(r.Data)) { + return 0, errors.New("read beyond end of ramdisk") + } + n = copy(p, r.Data[off:]) + return n, nil +} + +// WriteAt writes the given number of bytes to the block device. +func (r *RamDisk) WriteAt(p []byte, off int64) (n int, err error) { + if off >= int64(len(r.Data)) { + return 0, errors.New("write beyond end of ramdisk") + } + n = copy(r.Data[off:], p) + if n < len(p) { + return n, errors.New("write beyond end of ramdisk") + } + return n, nil +} + +// Size returns the number of bytes in this block device. +func (r *RamDisk) Size() int64 { + return int64(len(r.Data)) +} + +// WriteBlockSize returns the block size in which data can be written to +// memory. +func (r *RamDisk) WriteBlockSize() int64 { + return r.BlockSize +} + +// EraseBlockSize returns the smallest erasable area on this particular chip +// in bytes. +func (r *RamDisk) EraseBlockSize() int64 { + return r.BlockSize +} + +// EraseBlocks erases the given number of blocks. +func (r *RamDisk) EraseBlocks(start, len int64) error { + // Convert block numbers to byte offsets + startOffset := start * r.EraseBlockSize() + lengthBytes := len * r.EraseBlockSize() + + if startOffset+lengthBytes > int64(cap(r.Data)) { + return errors.New("erase beyond end of ramdisk") + } + + for i := int64(0); i < lengthBytes; i++ { + r.Data[startOffset+i] = 0xFF + } + return nil +} diff --git a/src/machine/usb/msc/scsi.go b/src/machine/usb/msc/scsi.go index 4cec23e2f2..d6e6aefc23 100644 --- a/src/machine/usb/msc/scsi.go +++ b/src/machine/usb/msc/scsi.go @@ -93,6 +93,7 @@ func (m *msc) scsiDataTransfer(b []byte) bool { // Update our sent bytes count to include the just-confirmed bytes m.sentBytes += m.queuedBytes + m.queuedBytes = 0 if m.sentBytes >= m.transferBytes { // Transfer complete, send CSW after transfer confirmed diff --git a/src/machine/usb/msc/setup.go b/src/machine/usb/msc/setup.go index 00507aac69..da35bb1ee4 100644 --- a/src/machine/usb/msc/setup.go +++ b/src/machine/usb/msc/setup.go @@ -1,7 +1,6 @@ package msc import ( - "machine" "machine/usb" ) @@ -58,7 +57,7 @@ func (m *msc) handleClearFeature(setup usb.Setup, wValue uint16) bool { } else if wIndex == usb.MSC_ENDPOINT_OUT { m.stallEndpoint(usb.MSC_ENDPOINT_OUT) } - machine.SendZlp() + m.usb.SendZlp() return true } @@ -81,7 +80,7 @@ func (m *msc) handleClearFeature(setup usb.Setup, wValue uint16) bool { } if ok { - machine.SendZlp() + m.usb.SendZlp() } return ok } @@ -95,7 +94,7 @@ func (m *msc) handleGetMaxLun(setup usb.Setup, wValue uint16) bool { // Send the maximum LUN ID number (zero-indexed, so n-1) supported by the device m.resetBuffer(1) // Shrink buffer to 1 byte m.buf[0] = m.maxLUN - return machine.SendUSBInPacket(usb.CONTROL_ENDPOINT, m.buf) + return m.usb.SendUSBInPacket(usb.CONTROL_ENDPOINT, m.buf) } // 3.1 Bulk-Only Mass Storage Reset @@ -114,7 +113,7 @@ func (m *msc) handleReset(setup usb.Setup, wValue uint16) bool { m.addlSenseQualifier = 0 // Send a zero-length packet (ZLP) to indicate the reset is complete - machine.SendZlp() + m.usb.SendZlp() // Return true to indicate successful reset return true @@ -123,21 +122,21 @@ func (m *msc) handleReset(setup usb.Setup, wValue uint16) bool { func (m *msc) stallEndpoint(ep uint8) { if ep == usb.MSC_ENDPOINT_IN { m.txStalled = true - machine.USBDev.SetStallEPIn(usb.MSC_ENDPOINT_IN) + m.usb.SetStallEPIn(usb.MSC_ENDPOINT_IN) } else if ep == usb.MSC_ENDPOINT_OUT { m.rxStalled = true - machine.USBDev.SetStallEPOut(usb.MSC_ENDPOINT_OUT) + m.usb.SetStallEPOut(usb.MSC_ENDPOINT_OUT) } else if ep == usb.CONTROL_ENDPOINT { - machine.USBDev.SetStallEPIn(usb.CONTROL_ENDPOINT) + m.usb.SetStallEPIn(usb.CONTROL_ENDPOINT) } } func (m *msc) clearStallEndpoint(ep uint8) { if ep == usb.MSC_ENDPOINT_IN { - machine.USBDev.ClearStallEPIn(usb.MSC_ENDPOINT_IN) + m.usb.ClearStallEPIn(usb.MSC_ENDPOINT_IN) m.txStalled = false } else if ep == usb.MSC_ENDPOINT_OUT { - machine.USBDev.ClearStallEPOut(usb.MSC_ENDPOINT_OUT) + m.usb.ClearStallEPOut(usb.MSC_ENDPOINT_OUT) m.rxStalled = false } }