chiark / gitweb /
Add some more documentation. And the zhop feature.
[cura.git] / Cura / avr_isp / stk500v2.py
1 """
2 STK500v2 protocol implementation for programming AVR chips.
3 The STK500v2 protocol is used by the ArduinoMega2560 and a few other Arduino platforms to load firmware.
4 """
5 __copyright__ = "Copyright (C) 2013 David Braam - Released under terms of the AGPLv3 License"
6 import os, struct, sys, time
7
8 from serial import Serial
9 from serial import SerialException
10
11 import ispBase, intelHex
12
class Stk500v2(ispBase.IspBase):
    """STK500v2 programmer that talks to an AVR bootloader over a serial port.

    Every message is framed as ``0x1B, seq, len_hi, len_lo, 0x0E, body...,
    checksum`` where the checksum is the XOR of all preceding bytes.

    The ``self.chip`` dict read by ``writeFlash``/``verifyFlash`` is not set in
    this class; presumably ``ispBase.IspBase`` fills it in before those are
    called (TODO confirm against ispBase).
    """

    # STK500v2 framing, status and command bytes (Atmel AVR068).
    _MESSAGE_START = 0x1B
    _TOKEN = 0x0E
    _STATUS_CMD_OK = 0x00
    _CMD_SIGN_ON = 0x01
    _CMD_LOAD_ADDRESS = 0x06
    _CMD_ENTER_PROGMODE_ISP = 0x10
    _CMD_LEAVE_PROGMODE_ISP = 0x11
    _CMD_PROGRAM_FLASH_ISP = 0x13
    _CMD_READ_FLASH_ISP = 0x14
    _CMD_SPI_MULTI = 0x1D

    # Number of bytes requested per CMD_READ_FLASH_ISP while verifying.
    _VERIFY_CHUNK = 0x100

    def __init__(self):
        self.serial = None  # open Serial object, or None when disconnected
        self.seq = 1  # sequence number written into the next outgoing frame
        self.lastAddr = -1  # not read in this class; kept for other users
        self.progressCallback = None  # optional callable(done, total)

    def connect(self, port='COM22', speed=115200):
        """Open *port*, reset the board and enter programming mode.

        Any previously open port is closed first.

        Raises ispBase.IspError if the port cannot be opened or the bootloader
        does not acknowledge. If the handshake fails, the port is closed again
        before the error propagates.
        """
        if self.serial is not None:
            self.close()
        try:
            self.serial = Serial(str(port), speed, timeout=1, writeTimeout=10000)
        except SerialException:
            raise ispBase.IspError("Failed to open serial port")
        except Exception:
            # str(port) so that a non-string port cannot turn this into a TypeError.
            raise ispBase.IspError("Unexpected error while connecting to serial port:" + str(port) + ":" + str(sys.exc_info()[0]))
        self.seq = 1

        connected = False
        try:
            # Pulse DTR to reset the controller.
            self.serial.setDTR(1)
            time.sleep(0.1)
            self.serial.setDTR(0)
            time.sleep(0.2)

            self.sendMessage([self._CMD_SIGN_ON])
            # The trailing 0xAC 0x53 0x00 0x00 is the AVR serial "Programming Enable" instruction.
            reply = self.sendMessage([self._CMD_ENTER_PROGMODE_ISP, 0xc8, 0x64, 0x19, 0x20, 0x00, 0x53, 0x03, 0xac, 0x53, 0x00, 0x00])
            if reply != [self._CMD_ENTER_PROGMODE_ISP, self._STATUS_CMD_OK]:
                raise ispBase.IspError("Failed to enter programming mode")
            connected = True
        finally:
            # Do not leave the port open when the handshake did not complete.
            if not connected:
                self.close()
        self.serial.timeout = 5

    def close(self):
        """Close the serial port if one is open."""
        if self.serial is not None:
            self.serial.close()
            self.serial = None

    def leaveISP(self):
        """Leave programming mode and hand the still-open serial port to the caller.

        This does not reset the serial port; it only resets the device. The
        caller can therefore keep using the port without opening it again.
        This instance forgets the port.

        Returns the Serial object, or None if not connected.
        Raises ispBase.IspError if the bootloader does not acknowledge.
        """
        if self.serial is None:
            return None
        if self.sendMessage([self._CMD_LEAVE_PROGMODE_ISP]) != [self._CMD_LEAVE_PROGMODE_ISP, self._STATUS_CMD_OK]:
            raise ispBase.IspError("Failed to leave programming mode")
        port, self.serial = self.serial, None
        return port

    def isConnected(self):
        """Return True while a serial port is held by this programmer."""
        return self.serial is not None

    def sendISP(self, data):
        """Send a 4-byte raw ISP instruction via CMD_SPI_MULTI.

        Returns the 4 bytes clocked back from the device.
        """
        recv = self.sendMessage([self._CMD_SPI_MULTI, 4, 4, 0, data[0], data[1], data[2], data[3]])
        return recv[2:6]

    def _loadAddressZero(self, flashSize):
        """Set the bootloader's load address to 0.

        For flash larger than 64k, the address extension bit is set.
        """
        if flashSize > 0xFFFF:
            self.sendMessage([self._CMD_LOAD_ADDRESS, 0x80, 0x00, 0x00, 0x00])
        else:
            self.sendMessage([self._CMD_LOAD_ADDRESS, 0x00, 0x00, 0x00, 0x00])

    def writeFlash(self, flashData):
        """Program *flashData* (a sequence of byte values) page by page from address 0.

        Raises ispBase.IspError if the bootloader rejects a page.
        """
        # pageSize is doubled to get bytes, so chip['pageSize'] is presumably in 16-bit words.
        pageSize = self.chip['pageSize'] * 2
        flashSize = pageSize * self.chip['pageCount']
        self._loadAddressZero(flashSize)

        loadCount = (len(flashData) + pageSize - 1) // pageSize
        for i in range(loadCount):
            page = list(flashData[i * pageSize:(i + 1) * pageSize])
            # The command header announces pageSize bytes, so pad a short final
            # page with 0xFF (the erased-flash value) to make the payload match.
            page += [0xFF] * (pageSize - len(page))
            reply = self.sendMessage([self._CMD_PROGRAM_FLASH_ISP, pageSize >> 8, pageSize & 0xFF, 0xc1, 0x0a, 0x40, 0x4c, 0x20, 0x00, 0x00] + page)
            if reply[:2] != [self._CMD_PROGRAM_FLASH_ISP, self._STATUS_CMD_OK]:
                raise ispBase.IspError('Failed to write flash at: 0x%x' % (i * pageSize))
            if self.progressCallback is not None:
                # Total is doubled because verification reports the second half of progress.
                self.progressCallback(i + 1, loadCount * 2)

    def verifyFlash(self, flashData):
        """Read flash back in 256-byte chunks and compare it with *flashData*.

        Raises ispBase.IspError if a read fails or at the first mismatching
        address.
        """
        flashSize = self.chip['pageSize'] * 2 * self.chip['pageCount']
        self._loadAddressZero(flashSize)

        chunk = self._VERIFY_CHUNK
        loadCount = (len(flashData) + chunk - 1) // chunk
        for i in range(loadCount):
            base = i * chunk
            # The reply carries [cmd, status, data...]; data starts at index 2.
            reply = self.sendMessage([self._CMD_READ_FLASH_ISP, chunk >> 8, chunk & 0xFF, 0x20])
            expected = flashData[base:base + chunk]
            actual = reply[2:2 + chunk]
            if reply[:2] != [self._CMD_READ_FLASH_ISP, self._STATUS_CMD_OK] or len(actual) < len(expected):
                raise ispBase.IspError('Failed to read flash at: 0x%x' % base)
            if self.progressCallback is not None:
                self.progressCallback(loadCount + i + 1, loadCount * 2)
            for offset in range(len(expected)):
                if expected[offset] != actual[offset]:
                    raise ispBase.IspError('Verify error at: 0x%x' % (base + offset))

    def sendMessage(self, data):
        """Frame *data* (a list of byte values), send it and return the reply body.

        Raises ispBase.IspError on a write timeout or when no valid reply arrives.
        """
        # Only Serial and SerialException are imported at module level.
        from serial import SerialTimeoutException

        message = bytearray(struct.pack(">BBHB", self._MESSAGE_START, self.seq, len(data), self._TOKEN))
        message.extend(data)
        checksum = 0
        for byte in message:  # bytearray iterates as ints on Python 2 and 3
            checksum ^= byte
        message.append(checksum)
        try:
            self.serial.write(bytes(message))
            self.serial.flush()
        except SerialTimeoutException:
            raise ispBase.IspError('Serial send timeout')
        self.seq = (self.seq + 1) & 0xFF
        return self.recvMessage()

    def recvMessage(self):
        """Read bytes until one frame with a valid checksum arrives.

        Frames with a wrong token or a bad checksum are skipped.

        Returns the frame body as a list of ints.
        Raises ispBase.IspError("Timeout") when a serial read returns no data.
        """
        state = 'Start'
        checksum = 0
        msgSize = 0
        data = []
        while True:
            s = self.serial.read()
            if len(s) < 1:
                raise ispBase.IspError("Timeout")
            b = struct.unpack(">B", s)[0]
            # Folding the checksum byte in too means a valid frame XORs to 0.
            checksum ^= b
            if state == 'Start':
                if b == self._MESSAGE_START:
                    state = 'GetSeq'
                    checksum = self._MESSAGE_START
            elif state == 'GetSeq':
                state = 'MsgSize1'
            elif state == 'MsgSize1':
                msgSize = b << 8
                state = 'MsgSize2'
            elif state == 'MsgSize2':
                msgSize |= b
                state = 'Token'
            elif state == 'Token':
                if b != self._TOKEN:
                    state = 'Start'
                else:
                    data = []
                    # An empty body is followed directly by the checksum byte.
                    state = 'Data' if msgSize > 0 else 'Checksum'
            elif state == 'Data':
                data.append(b)
                if len(data) == msgSize:
                    state = 'Checksum'
            elif state == 'Checksum':
                if checksum != 0:
                    state = 'Start'
                else:
                    return data
151
def portList():
    """Return the USB serial ports registered in the Windows registry.

    Reads the values under HKLM\\HARDWARE\\DEVICEMAP\\SERIALCOMM. It keeps the
    data of every value whose name contains 'USBSER'.

    Windows only: on other platforms the registry import raises ImportError.
    Returns an empty list if the key cannot be opened.
    """
    try:
        import _winreg as winreg  # Python 2
    except ImportError:
        import winreg  # Python 3
    ret = []
    try:
        key = winreg.OpenKey(winreg.HKEY_LOCAL_MACHINE, "HARDWARE\\DEVICEMAP\\SERIALCOMM")
    except OSError:
        return ret
    try:
        i = 0
        while True:
            try:
                name, value, _ = winreg.EnumValue(key, i)
            except OSError:
                # EnumValue raises once the index runs past the last value.
                break
            if 'USBSER' in name:
                ret.append(value)
            i += 1
    finally:
        winreg.CloseKey(key)
    return ret
166
def runProgrammer(port, filename):
    """Flash the Intel HEX file *filename* onto the board attached to *port*.

    The serial port is closed afterwards, even when reading the file or
    programming the chip raises.
    """
    programmer = Stk500v2()
    programmer.connect(port=port)
    try:
        programmer.programChip(intelHex.readHex(filename))
    finally:
        programmer.close()
172
def main():
    """Command-line entry point: ``stk500v2.py <port|AUTO> <firmware.hex>``.

    With AUTO, every USB serial port from portList() is programmed. Each port
    gets its own thread, and the threads are started 5 seconds apart.
    Otherwise the single given port is programmed and closed afterwards.
    Exits with a usage message when arguments are missing.
    """
    import threading
    if len(sys.argv) < 3:
        sys.exit("Usage: %s <port|AUTO> <firmware.hex>" % sys.argv[0])
    port, filename = sys.argv[1], sys.argv[2]
    if port == 'AUTO':
        ports = portList()
        print(ports)
        for p in ports:
            threading.Thread(target=runProgrammer, args=(p, filename)).start()
            time.sleep(5)
    else:
        runProgrammer(port, filename)


if __name__ == '__main__':
    main()