chiark / gitweb /
Add some more documentation. And the zhop feature.
[cura.git] / Cura / avr_isp / stk500v2.py
index a66b341db2212d80ad1787285d4468bd64b033f7..975dd30b933d8d89630bdd9dbe8ac078dd39afff 100644 (file)
-import os, struct, sys, time\r
-\r
-from serial import Serial\r
-from serial import SerialException\r
-\r
-import ispBase, intelHex\r
-\r
-class Stk500v2(ispBase.IspBase):\r
-       def __init__(self):\r
-               self.serial = None\r
-               self.seq = 1\r
-               self.lastAddr = -1\r
-               self.progressCallback = None\r
-       \r
-       def connect(self, port = 'COM17', speed = 115200):\r
-               if self.serial != None:\r
-                       self.close()\r
-               try:\r
-                       self.serial = Serial(port, speed, timeout=1, writeTimeout=10000)\r
-               except SerialException as e:\r
-                       raise ispBase.IspError("Failed to open serial port")\r
-               except:\r
-                       raise ispBase.IspError("Unexpected error while connecting to serial port:" + port + ":" + str(sys.exc_info()[0]))\r
-               self.seq = 1\r
-               \r
-               #Reset the controller\r
-               self.serial.setDTR(1)\r
-               time.sleep(0.1)\r
-               self.serial.setDTR(0)\r
-               time.sleep(0.2)\r
-               \r
-               self.sendMessage([1])\r
-               if self.sendMessage([0x10, 0xc8, 0x64, 0x19, 0x20, 0x00, 0x53, 0x03, 0xac, 0x53, 0x00, 0x00]) != [0x10, 0x00]:\r
-                       self.close()\r
-                       raise ispBase.IspError("Failed to enter programming mode")\r
-\r
-       def close(self):\r
-               if self.serial != None:\r
-                       self.serial.close()\r
-                       self.serial = None\r
-\r
-       #Leave ISP does not reset the serial port, only resets the device, and returns the serial port after disconnecting it from the programming interface.\r
-       #       This allows you to use the serial port without opening it again.\r
-       def leaveISP(self):\r
-               if self.serial != None:\r
-                       if self.sendMessage([0x11]) != [0x11, 0x00]:\r
-                               raise ispBase.IspError("Failed to leave programming mode")\r
-                       ret = self.serial\r
-                       self.serial = None\r
-                       return ret\r
-               return None\r
-       \r
-       def isConnected(self):\r
-               return self.serial != None\r
-       \r
-       def sendISP(self, data):\r
-               recv = self.sendMessage([0x1D, 4, 4, 0, data[0], data[1], data[2], data[3]])\r
-               return recv[2:6]\r
-       \r
-       def writeFlash(self, flashData):\r
-               #Set load addr to 0, in case we have more then 64k flash we need to enable the address extension\r
-               pageSize = self.chip['pageSize'] * 2\r
-               flashSize = pageSize * self.chip['pageCount']\r
-               if flashSize > 0xFFFF:\r
-                       self.sendMessage([0x06, 0x80, 0x00, 0x00, 0x00])\r
-               else:\r
-                       self.sendMessage([0x06, 0x00, 0x00, 0x00, 0x00])\r
-               \r
-               loadCount = (len(flashData) + pageSize - 1) / pageSize\r
-               for i in xrange(0, loadCount):\r
-                       recv = self.sendMessage([0x13, pageSize >> 8, pageSize & 0xFF, 0xc1, 0x0a, 0x40, 0x4c, 0x20, 0x00, 0x00] + flashData[(i * pageSize):(i * pageSize + pageSize)])\r
-                       if self.progressCallback != None:\r
-                               self.progressCallback(i + 1, loadCount*2)\r
-       \r
-       def verifyFlash(self, flashData):\r
-               #Set load addr to 0, in case we have more then 64k flash we need to enable the address extension\r
-               flashSize = self.chip['pageSize'] * 2 * self.chip['pageCount']\r
-               if flashSize > 0xFFFF:\r
-                       self.sendMessage([0x06, 0x80, 0x00, 0x00, 0x00])\r
-               else:\r
-                       self.sendMessage([0x06, 0x00, 0x00, 0x00, 0x00])\r
-               \r
-               loadCount = (len(flashData) + 0xFF) / 0x100\r
-               for i in xrange(0, loadCount):\r
-                       recv = self.sendMessage([0x14, 0x01, 0x00, 0x20])[2:0x102]\r
-                       if self.progressCallback != None:\r
-                               self.progressCallback(loadCount + i + 1, loadCount*2)\r
-                       for j in xrange(0, 0x100):\r
-                               if i * 0x100 + j < len(flashData) and flashData[i * 0x100 + j] != recv[j]:\r
-                                       raise ispBase.IspError('Verify error at: 0x%x' % (i * 0x100 + j))\r
-\r
-       def sendMessage(self, data):\r
-               message = struct.pack(">BBHB", 0x1B, self.seq, len(data), 0x0E)\r
-               for c in data:\r
-                       message += struct.pack(">B", c)\r
-               checksum = 0\r
-               for c in message:\r
-                       checksum ^= ord(c)\r
-               message += struct.pack(">B", checksum)\r
-               try:\r
-                       self.serial.write(message)\r
-                       self.serial.flush()\r
-               except SerialTimeoutException:\r
-                       raise ispBase.IspError('Serial send timeout')\r
-               self.seq = (self.seq + 1) & 0xFF\r
-               return self.recvMessage()\r
-       \r
-       def recvMessage(self):\r
-               state = 'Start'\r
-               checksum = 0\r
-               while True:\r
-                       s = self.serial.read()\r
-                       if len(s) < 1:\r
-                               raise ispBase.IspError("Timeout")\r
-                       b = struct.unpack(">B", s)[0]\r
-                       checksum ^= b\r
-                       #print(hex(b))\r
-                       if state == 'Start':\r
-                               if b == 0x1B:\r
-                                       state = 'GetSeq'\r
-                                       checksum = 0x1B\r
-                       elif state == 'GetSeq':\r
-                               state = 'MsgSize1'\r
-                       elif state == 'MsgSize1':\r
-                               msgSize = b << 8\r
-                               state = 'MsgSize2'\r
-                       elif state == 'MsgSize2':\r
-                               msgSize |= b\r
-                               state = 'Token'\r
-                       elif state == 'Token':\r
-                               if b != 0x0E:\r
-                                       state = 'Start'\r
-                               else:\r
-                                       state = 'Data'\r
-                                       data = []\r
-                       elif state == 'Data':\r
-                               data.append(b)\r
-                               if len(data) == msgSize:\r
-                                       state = 'Checksum'\r
-                       elif state == 'Checksum':\r
-                               if checksum != 0:\r
-                                       state = 'Start'\r
-                               else:\r
-                                       return data\r
-\r
-\r
-def main():\r
-       programmer = Stk500v2()\r
-       programmer.connect()\r
-       programmer.programChip(intelHex.readHex(sys.argv[1]))\r
-       sys.exit(1)\r
-\r
-if __name__ == '__main__':\r
-       main()\r
+"""
+STK500v2 protocol implementation for programming AVR chips.
+The STK500v2 protocol is used by the ArduinoMega2560 and a few other Arduino platforms to load firmware.
+"""
+__copyright__ = "Copyright (C) 2013 David Braam - Released under terms of the AGPLv3 License"
+import os, struct, sys, time
+
+from serial import Serial
+from serial import SerialException
+
+import ispBase, intelHex
+
+class Stk500v2(ispBase.IspBase):
+       def __init__(self):
+               self.serial = None
+               self.seq = 1
+               self.lastAddr = -1
+               self.progressCallback = None
+       
+       def connect(self, port = 'COM22', speed = 115200):
+               if self.serial is not None:
+                       self.close()
+               try:
+                       self.serial = Serial(str(port), speed, timeout=1, writeTimeout=10000)
+               except SerialException as e:
+                       raise ispBase.IspError("Failed to open serial port")
+               except:
+                       raise ispBase.IspError("Unexpected error while connecting to serial port:" + port + ":" + str(sys.exc_info()[0]))
+               self.seq = 1
+               
+               #Reset the controller
+               self.serial.setDTR(1)
+               time.sleep(0.1)
+               self.serial.setDTR(0)
+               time.sleep(0.2)
+
+               self.sendMessage([1])
+               if self.sendMessage([0x10, 0xc8, 0x64, 0x19, 0x20, 0x00, 0x53, 0x03, 0xac, 0x53, 0x00, 0x00]) != [0x10, 0x00]:
+                       self.close()
+                       raise ispBase.IspError("Failed to enter programming mode")
+               self.serial.timeout = 5
+
+       def close(self):
+               if self.serial is not None:
+                       self.serial.close()
+                       self.serial = None
+
+       #Leave ISP does not reset the serial port, only resets the device, and returns the serial port after disconnecting it from the programming interface.
+       #       This allows you to use the serial port without opening it again.
+       def leaveISP(self):
+               if self.serial is not None:
+                       if self.sendMessage([0x11]) != [0x11, 0x00]:
+                               raise ispBase.IspError("Failed to leave programming mode")
+                       ret = self.serial
+                       self.serial = None
+                       return ret
+               return None
+       
+       def isConnected(self):
+               return self.serial is not None
+       
+       def sendISP(self, data):
+               recv = self.sendMessage([0x1D, 4, 4, 0, data[0], data[1], data[2], data[3]])
+               return recv[2:6]
+       
+       def writeFlash(self, flashData):
+               #Set load addr to 0, in case we have more then 64k flash we need to enable the address extension
+               pageSize = self.chip['pageSize'] * 2
+               flashSize = pageSize * self.chip['pageCount']
+               if flashSize > 0xFFFF:
+                       self.sendMessage([0x06, 0x80, 0x00, 0x00, 0x00])
+               else:
+                       self.sendMessage([0x06, 0x00, 0x00, 0x00, 0x00])
+               
+               loadCount = (len(flashData) + pageSize - 1) / pageSize
+               for i in xrange(0, loadCount):
+                       recv = self.sendMessage([0x13, pageSize >> 8, pageSize & 0xFF, 0xc1, 0x0a, 0x40, 0x4c, 0x20, 0x00, 0x00] + flashData[(i * pageSize):(i * pageSize + pageSize)])
+                       if self.progressCallback != None:
+                               self.progressCallback(i + 1, loadCount*2)
+       
+       def verifyFlash(self, flashData):
+               #Set load addr to 0, in case we have more then 64k flash we need to enable the address extension
+               flashSize = self.chip['pageSize'] * 2 * self.chip['pageCount']
+               if flashSize > 0xFFFF:
+                       self.sendMessage([0x06, 0x80, 0x00, 0x00, 0x00])
+               else:
+                       self.sendMessage([0x06, 0x00, 0x00, 0x00, 0x00])
+               
+               loadCount = (len(flashData) + 0xFF) / 0x100
+               for i in xrange(0, loadCount):
+                       recv = self.sendMessage([0x14, 0x01, 0x00, 0x20])[2:0x102]
+                       if self.progressCallback != None:
+                               self.progressCallback(loadCount + i + 1, loadCount*2)
+                       for j in xrange(0, 0x100):
+                               if i * 0x100 + j < len(flashData) and flashData[i * 0x100 + j] != recv[j]:
+                                       raise ispBase.IspError('Verify error at: 0x%x' % (i * 0x100 + j))
+
+       def sendMessage(self, data):
+               message = struct.pack(">BBHB", 0x1B, self.seq, len(data), 0x0E)
+               for c in data:
+                       message += struct.pack(">B", c)
+               checksum = 0
+               for c in message:
+                       checksum ^= ord(c)
+               message += struct.pack(">B", checksum)
+               try:
+                       self.serial.write(message)
+                       self.serial.flush()
+               except Serial.SerialTimeoutException:
+                       raise ispBase.IspError('Serial send timeout')
+               self.seq = (self.seq + 1) & 0xFF
+               return self.recvMessage()
+       
+       def recvMessage(self):
+               state = 'Start'
+               checksum = 0
+               while True:
+                       s = self.serial.read()
+                       if len(s) < 1:
+                               raise ispBase.IspError("Timeout")
+                       b = struct.unpack(">B", s)[0]
+                       checksum ^= b
+                       #print(hex(b))
+                       if state == 'Start':
+                               if b == 0x1B:
+                                       state = 'GetSeq'
+                                       checksum = 0x1B
+                       elif state == 'GetSeq':
+                               state = 'MsgSize1'
+                       elif state == 'MsgSize1':
+                               msgSize = b << 8
+                               state = 'MsgSize2'
+                       elif state == 'MsgSize2':
+                               msgSize |= b
+                               state = 'Token'
+                       elif state == 'Token':
+                               if b != 0x0E:
+                                       state = 'Start'
+                               else:
+                                       state = 'Data'
+                                       data = []
+                       elif state == 'Data':
+                               data.append(b)
+                               if len(data) == msgSize:
+                                       state = 'Checksum'
+                       elif state == 'Checksum':
+                               if checksum != 0:
+                                       state = 'Start'
+                               else:
+                                       return data
+
+def portList():
+       ret = []
+       import _winreg
+       key=_winreg.OpenKey(_winreg.HKEY_LOCAL_MACHINE,"HARDWARE\\DEVICEMAP\\SERIALCOMM")
+       i=0
+       while True:
+               try:
+                       values = _winreg.EnumValue(key, i)
+               except:
+                       return ret
+               if 'USBSER' in values[0]:
+                       ret.append(values[1])
+               i+=1
+       return ret
+
+def runProgrammer(port, filename):
+       programmer = Stk500v2()
+       programmer.connect(port = port)
+       programmer.programChip(intelHex.readHex(filename))
+       programmer.close()
+
+def main():
+       import threading
+       if sys.argv[1] == 'AUTO':
+               print portList()
+               for port in portList():
+                       threading.Thread(target=runProgrammer, args=(port,sys.argv[2])).start()
+                       time.sleep(5)
+       else:
+               programmer = Stk500v2()
+               programmer.connect(port = sys.argv[1])
+               programmer.programChip(intelHex.readHex(sys.argv[2]))
+               sys.exit(1)
+
+if __name__ == '__main__':
+       main()