chiark / gitweb /
Fix pyserial 3.0 compatibility issue
[cura.git] / Cura / avr_isp / stk500v2.py
1 """
2 STK500v2 protocol implementation for programming AVR chips.
3 The STK500v2 protocol is used by the ArduinoMega2560 and a few other Arduino platforms to load firmware.
4 """
5 __copyright__ = "Copyright (C) 2013 David Braam - Released under terms of the AGPLv3 License"
6 import os, struct, sys, time
7
8 from serial import Serial
9 from serial import SerialException
10 from serial import SerialTimeoutException
11
12 import ispBase, intelHex
13
class Stk500v2(ispBase.IspBase):
    """STK500v2 programmer talking to an AVR bootloader over a serial port.

    Instance attributes set in this class:
        serial -- the open ``Serial`` object, or None when not connected.
        seq -- sequence number (0..255) placed in the next outgoing frame.
        lastAddr -- initialised to -1; not read anywhere in this class.
        progressCallback -- optional ``callback(done, total)`` invoked while
            writing and verifying flash.
    """

    def __init__(self):
        self.serial = None
        self.seq = 1
        self.lastAddr = -1
        self.progressCallback = None
        # Initialised here so hasChecksumFunction() answers False instead of
        # raising AttributeError when it is called before connect().
        self._has_checksum = False

    def connect(self, port='COM22', speed=115200):
        """Open the serial port, reset the board and enter programming mode.

        Any previously open port is closed first. If anything fails after the
        port was opened, the port is closed again before the error propagates,
        so a failed connect() leaves the programmer disconnected.

        :param port: serial port name; converted with str().
        :param speed: baud rate.
        :raises ispBase.IspError: if the port cannot be opened, the device does
            not answer, or it refuses to enter programming mode.
        """
        if self.serial is not None:
            self.close()
        try:
            self.serial = Serial(str(port), speed, timeout=1)
            # Need to set writeTimeout separately in order to be compatible with pyserial 3.0
            self.serial.writeTimeout = 10000
        except SerialException:
            self.close()
            raise ispBase.IspError("Failed to open serial port")
        except Exception:
            excType = sys.exc_info()[0]
            self.close()
            # str(port): port need not be a string, and plain "+ port" would
            # replace the real error with a TypeError.
            raise ispBase.IspError("Unexpected error while connecting to serial port:" + str(port) + ":" + str(excType))
        self.seq = 1

        try:
            # Reset the controller: toggle DTR twice, then give it time to start.
            for _ in range(2):
                self.serial.setDTR(True)
                time.sleep(0.1)
                self.serial.setDTR(False)
                time.sleep(0.1)
            time.sleep(0.2)

            self.serial.flushInput()
            self.serial.flushOutput()
            self.sendMessage([1])
            if self.sendMessage([0x10, 0xc8, 0x64, 0x19, 0x20, 0x00, 0x53, 0x03, 0xac, 0x53, 0x00, 0x00]) != [0x10, 0x00]:
                raise ispBase.IspError("Failed to enter programming mode")

            # Probe for the bootloader's checksum command (0xEE, used by
            # verifyFlash): a 0x00 status byte in the reply means it is supported.
            self.sendMessage([0x06, 0x80, 0x00, 0x00, 0x00])
            reply = self.sendMessage([0xEE])
            self._has_checksum = len(reply) > 1 and reply[1] == 0x00
        except Exception:
            # Do not leave a half-initialised connection open behind an error.
            self.close()
            raise
        # Use a longer read timeout from here on.
        self.serial.timeout = 5

    def close(self):
        """Close the serial port if one is open; safe to call repeatedly."""
        if self.serial is not None:
            self.serial.close()
            self.serial = None

    def leaveISP(self):
        """Leave programming mode and hand back the still-open serial port.

        Leave ISP does not reset the serial port. It only resets the device,
        and it returns the serial port after disconnecting it from the
        programming interface. This allows the caller to keep using the port
        without opening it again.

        :returns: the ``Serial`` object, or None if not connected.
        :raises ispBase.IspError: if the device does not acknowledge.
        """
        if self.serial is not None:
            if self.sendMessage([0x11]) != [0x11, 0x00]:
                raise ispBase.IspError("Failed to leave programming mode")
            ret = self.serial
            self.serial = None
            return ret
        return None

    def isConnected(self):
        """Return True while a serial port is held by this programmer."""
        return self.serial is not None

    def hasChecksumFunction(self):
        """Return True if connect() found the bootloader's 0xEE checksum command."""
        return self._has_checksum

    def sendISP(self, data):
        """Send a raw 4-byte ISP command (0x1D) and return the reply bytes [2:6]."""
        recv = self.sendMessage([0x1D, 4, 4, 0, data[0], data[1], data[2], data[3]])
        return recv[2:6]

    def _setLoadAddressZero(self):
        """Set the load address to 0, enabling the address extension above 64k of flash."""
        flashSize = self.chip['pageSize'] * 2 * self.chip['pageCount']
        if flashSize > 0xFFFF:
            self.sendMessage([0x06, 0x80, 0x00, 0x00, 0x00])
        else:
            self.sendMessage([0x06, 0x00, 0x00, 0x00, 0x00])

    def writeFlash(self, flashData):
        """Program ``flashData`` into flash from address 0, one page per message.

        :param flashData: list of byte values (ints 0..255).
        """
        # The chip table's pageSize is doubled here (presumably it is given in
        # 16-bit words) to get the page size in bytes.
        pageSize = self.chip['pageSize'] * 2
        self._setLoadAddressZero()

        # Floor division keeps this an int under true division / Python 3.
        loadCount = (len(flashData) + pageSize - 1) // pageSize
        for i in range(loadCount):
            page = list(flashData[i * pageSize:(i + 1) * pageSize])
            self.sendMessage([0x13, pageSize >> 8, pageSize & 0xFF, 0xc1, 0x0a, 0x40, 0x4c, 0x20, 0x00, 0x00] + page)
            if self.progressCallback is not None:
                # Without the checksum command verifyFlash re-reads every byte,
                # so writing only covers the first half of the progress range.
                if self._has_checksum:
                    self.progressCallback(i + 1, loadCount)
                else:
                    self.progressCallback(i + 1, loadCount * 2)

    def verifyFlash(self, flashData):
        """Check that flash matches ``flashData``.

        Uses the bootloader's 16-bit checksum command when available, otherwise
        reads flash back in 256-byte blocks and compares byte by byte.

        :raises ispBase.IspError: on checksum mismatch, byte mismatch, or a
            reply too short to check.
        """
        if self._has_checksum:
            # Half the image length is sent through the load-address command;
            # presumably the bootloader uses it as the range to checksum.
            self.sendMessage([0x06, 0x00, (len(flashData) >> 17) & 0xFF, (len(flashData) >> 9) & 0xFF, (len(flashData) >> 1) & 0xFF])
            res = self.sendMessage([0xEE])
            if len(res) < 4:
                raise ispBase.IspError('Verify checksum: unexpected reply %r' % (res,))
            checksum_recv = res[2] | (res[3] << 8)
            checksum = sum(flashData) & 0xFFFF
            # Compare the integers directly; comparing hex() strings breaks on
            # Python 2 longs because of their trailing 'L'.
            if checksum != checksum_recv:
                raise ispBase.IspError('Verify checksum mismatch: 0x%x != 0x%x' % (checksum, checksum_recv))
        else:
            self._setLoadAddressZero()

            loadCount = (len(flashData) + 0xFF) // 0x100
            for i in range(loadCount):
                recv = self.sendMessage([0x14, 0x01, 0x00, 0x20])[2:0x102]
                if self.progressCallback is not None:
                    self.progressCallback(loadCount + i + 1, loadCount * 2)
                base = i * 0x100
                # Slicing flashData limits the check to bytes that exist in the
                # image; a short reply counts as a mismatch, not an IndexError.
                for j, expected in enumerate(flashData[base:base + 0x100]):
                    if j >= len(recv) or recv[j] != expected:
                        raise ispBase.IspError('Verify error at: 0x%x' % (base + j))

    def sendMessage(self, data):
        """Frame ``data`` as an STK500v2 message, send it and return the reply body.

        Frame layout: 0x1B, seq, 16-bit big-endian body length, 0x0E, body,
        then one XOR checksum byte over everything before it.

        :param data: list of byte values (ints 0..255).
        :returns: the reply body from recvMessage() as a list of ints.
        :raises ispBase.IspError: on a serial write timeout or a read timeout.
        """
        # bytearray iterates as ints on both Python 2 and 3 (ord() on packed
        # bytes only works on Python 2).
        message = bytearray(struct.pack(">BBHB", 0x1B, self.seq, len(data), 0x0E))
        message.extend(data)
        checksum = 0
        for b in message:
            checksum ^= b
        message.append(checksum)
        try:
            self.serial.write(bytes(message))
            self.serial.flush()
        except SerialTimeoutException:
            raise ispBase.IspError('Serial send timeout')
        self.seq = (self.seq + 1) & 0xFF
        return self.recvMessage()

    def recvMessage(self):
        """Read one STK500v2 reply frame and return its body as a list of ints.

        Bytes are consumed until a frame with start byte 0x1B, token 0x0E and
        a zero running XOR checksum is found; other frames are skipped. The
        sequence byte is read but not compared with self.seq.

        :raises ispBase.IspError: "Timeout" when a serial read returns nothing.
        """
        state = 'Start'
        checksum = 0
        msgSize = 0
        data = []
        while True:
            s = self.serial.read()
            if len(s) < 1:
                raise ispBase.IspError("Timeout")
            b = struct.unpack(">B", s)[0]
            checksum ^= b
            if state == 'Start':
                if b == 0x1B:
                    state = 'GetSeq'
                    checksum = 0x1B
            elif state == 'GetSeq':
                state = 'MsgSize1'
            elif state == 'MsgSize1':
                msgSize = b << 8
                state = 'MsgSize2'
            elif state == 'MsgSize2':
                msgSize |= b
                state = 'Token'
            elif state == 'Token':
                if b != 0x0E:
                    state = 'Start'
                else:
                    data = []
                    # An empty body goes straight to the checksum byte; the
                    # 'Data' state would otherwise swallow it as payload.
                    state = 'Data' if msgSize > 0 else 'Checksum'
            elif state == 'Data':
                data.append(b)
                if len(data) == msgSize:
                    state = 'Checksum'
            elif state == 'Checksum':
                if checksum != 0:
                    state = 'Start'
                else:
                    return data
181
def portList():
    """Return serial port names from the Windows registry whose entry mentions USBSER.

    Enumerates the values under HKEY_LOCAL_MACHINE\\HARDWARE\\DEVICEMAP\\SERIALCOMM.
    It collects the data of every value whose name contains 'USBSER'. Windows
    only: elsewhere the registry module import raises ImportError.
    """
    try:
        import _winreg as winreg  # Python 2 name
    except ImportError:
        import winreg  # Python 3 name
    ret = []
    key = winreg.OpenKey(winreg.HKEY_LOCAL_MACHINE, "HARDWARE\\DEVICEMAP\\SERIALCOMM")
    try:
        i = 0
        while True:
            try:
                values = winreg.EnumValue(key, i)
            except EnvironmentError:
                # EnumValue raises once the index runs past the last value.
                break
            if 'USBSER' in values[0]:
                ret.append(values[1])
            i += 1
    finally:
        # Release the registry handle instead of leaking it.
        winreg.CloseKey(key)
    return ret
196
def runProgrammer(port, filename):
    """ Run an STK500v2 program on serial port 'port' and write 'filename' into flash.

    The programmer is closed even when connecting, reading the hex file or
    programming raises, so the serial port is not leaked.
    """
    programmer = Stk500v2()
    try:
        programmer.connect(port=port)
        programmer.programChip(intelHex.readHex(filename))
    finally:
        programmer.close()
203
def main():
    """ Entry point to call the stk500v2 programmer from the commandline.

    Usage: stk500v2.py <port|AUTO> <hexfile>

    With AUTO, the ports returned by portList() are printed. Each one is then
    programmed in its own thread, with a new thread started every 5 seconds.
    Otherwise the given port is programmed directly and the port is closed
    afterwards.
    """
    import threading
    if len(sys.argv) < 3:
        sys.stderr.write("Usage: %s <port|AUTO> <hexfile>\n" % sys.argv[0])
        sys.exit(2)
    if sys.argv[1] == 'AUTO':
        # Query the registry once instead of once for printing and once for the loop.
        ports = portList()
        print(ports)
        for port in ports:
            threading.Thread(target=runProgrammer, args=(port, sys.argv[2])).start()
            time.sleep(5)
    else:
        # runProgrammer closes the port afterwards; a successful run now exits
        # with status 0 instead of the previous unconditional sys.exit(1).
        runProgrammer(sys.argv[1], sys.argv[2])


if __name__ == '__main__':
    main()