
Source Code for Module viterbi.viterbi_search

#!/usr/bin/env python

"""
Code implementing a Python version of the Viterbi algorithm for
HMMs, which computes the highest probability path through the
HMM, given a sequence of observations.

Path probabilities through the HMM are computed using TWO probability
models, implemented as dictionaries. The first is a B{transition
probability model} (C{A}) and the second is the B{emission probability
model} (C{B}). The transition prob::
       A[s1][s2]
returns the probability of transitioning to s2 given that you are at
state s1 (Pr(s2|s1)).  The emission prob::
      B[o][s]
is the probability of observing a string C{o} given you are at state
C{s} (Pr(o|s)).  Note that C{A} and C{B} are both dictionaries of
dictionaries, but C{A} is a dictionary with state keys, and C{B} is a
dictionary with observation keys. Symbols not in the HMM alphabet
will produce errors.
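
For example, with Machine 3 defined below (values as written in the
source, before the module-level conversion to log probs)::

    >>> A3['HOT']['COLD']    # Pr(COLD|HOT)
    0.3
    >>> B3['3']['HOT']       # Pr('3'|HOT)
    0.4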

To run the decoder:

   >>> decode_line('313', A3, B3, states3, start3)

This runs with Machine 3, the ice cream/weather machine of
Chapter 6, Jurafsky & Martin, with the input diagrammed
in Fig 6.10.  The first argument is the input and the next
four define the HMM::
    -- A3: the state transition probs
    -- B3: the observation probs
    -- states3: the set of states of the machine
    -- start3: the start state

A final optional Boolean argument toggles HMM emission prob
predicting.  With predicting turned on, emission probs
for the state one is coming FROM are used, which means one had to
transition to that state before seeing the emission that mattered
there. So one had to "predict" right (i.e., guess).  The default
value for this argument is False (predicting off).
The only machine defined in this file that runs with
predicting turned on is Machine 1:

    >>> decode_line('httt', A1, B1, states1, start1, True)

With predicting turned off, emission probs for the state one is going
to are used, which means one sees the emission probs and the emission
at the same time, and no predicting is necessary to choose the state
with the highest emission prob for the current emission.
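
In terms of the code in C{decode} below, the flag only changes which
state's emission prob enters each path score::

    if predict:
        e_state = s_from    # emission prob of the state being left
    else:
        e_state = s_to      # emission prob of the state being entered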

This file includes three example HMMs.  The first is a simple coin
flipping HMM, which, with predicting turned on, tries to "guess" the
result of the next coin flip in a sequence of coin flips by transitioning
at each observation to either a head-predicting state ('H') or
a tail-predicting state ('T').  The observation alphabet
consists of the lower case letters 'h' and 't'.

Machine 2 is the machine used for the word decoder in
the first edition of U{Jurafsky and Martin, Ch. 7,
Fig. 9<http://www-rohan.sdsu.edu/~gawron/compling/chap7/fig07.09.pdf>},
which gives path probabilities for decoding a sequence of phones as a
sequence of words.  A state C{state} is a pair of a phone and a word.
The emission probs in C{B2} are always either 1.0 or 0.0: 1.0 for any
input phone that matches the phone the state represents, C{state[0]},
0 for all others. For example, state C{('dh','the')} has probability
1.0 for "dh" and probability 0 for all other phones. The transition
probs, C{A2}, are of the form::
    A2[phone1][phone2]

and represent the probability of phone2 given phone1. The table includes
cross-word transitions.  These are computed by combining the
probability models in L{viterbi.pron_probs} and
L{viterbi.bigram_probs}.
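
For instance, the cross-word transition from state C{('d','need')} to
state C{('dh','the')} (an illustrative pair from C{states2}) is
estimated as the product described under L{make_transition_dict}::

    Pr(end | 'need', 'd') * Pr('dh' | start, 'the') * Pr('the' | 'need')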


@variable A1: Transition probs for HMM 1,
              U{the head-tails machine<http://www-rohan.sdsu.edu/~gawron/compling/course_core/lectures/coin_tosser.pdf>}.
              Transitions to the start state always have prob 0; transitions
              to either of the other states always have prob 0.5.
@variable B1: Observation probs for HMM 1,
              U{the head-tails machine<http://www-rohan.sdsu.edu/~gawron/compling/course_core/lectures/coin_tosser.pdf>}.
              The start state has unbiased observation probs.
              The other 2 have observation probs 1.0 for
              heads and tails respectively.
@variable start1: Start state for HMM 1,
              U{the head-tails machine<http://www-rohan.sdsu.edu/~gawron/compling/course_core/lectures/coin_tosser.pdf>}.
@variable states1: States list for HMM 1,
              U{the head-tails machine<http://www-rohan.sdsu.edu/~gawron/compling/course_core/lectures/coin_tosser.pdf>}.
@variable start2: The starting state for HMM 2, which is the HMM implicit in
                  U{J&M, Figure 7.10<http://www-rohan.sdsu.edu/~gawron/compling/chap7/fig07.10.pdf>}.
@variable states2: The states for HMM 2, which is the HMM implicit in
                   U{J&M, Figure 7.10<http://www-rohan.sdsu.edu/~gawron/compling/chap7/fig07.10.pdf>}.
"""

neg_infinity = float("-infinity")


######################################################################
######################################################################
###  M a i n      P r o g r a m
######################################################################
######################################################################

back = []
trellis = []
def decode(Stringseq, A, B, states, s0, predict=False):
    """
    Takes a list of strings, C{Stringseq}, each element an observation,
    and decodes it according to the HMM defined by C{A}, C{B}, C{states},
    and C{s0}. Returns C{path}, a list of states of length len(Stringseq)+1,
    representing the highest prob HMM path that accepts Stringseq, as well
    as the two tables built in the Viterbi computation. Note:
    implements the pseudocode of U{J&M, Figure 7.09<http://www-rohan.sdsu.edu/~gawron/compling/chap7/fig07.09.pdf>}.

    C{trellis} is the table the Viterbi algorithm fills in.
    Therefore, C{trellis} will be an array of length T+1 (length of input + t=0).
    Each element will be a dictionary with states as keys:
    trellis[t][s] is the Viterbi score of state s at time t [a log prob].

    C{back} (for "backtrace") will be an array with the same dimensions::

        back[t][s]

    is the best state to have come FROM to get to s at t.

    @param Stringseq: a list of strings representing the sequence of input
                      observations
    @param A: Transition prob table of HMM
    @param B: Observation prob table of HMM
    @param states: A list of states for HMM
    @param s0: Start state of HMM
    @rtype: a 4-tuple of C{trellis} (the Viterbi table), C{back} (the backtrace
            table), the best probability path through the HMM (list), and
            C{nice_path_string} (a pretty string representation of the best
            path, suitable for printing).
    """
    global trellis, back
    T = len(Stringseq)
    # print 'T: %s' % T
    trellis = []
    back = []
    ############################################################################
    # Initialize trellis and back.
    ############################################################################
    for t in xrange(T+1): # initialize trellis and back
        # Use xrange rather than range: efficiency
        # viterbi scores for this t; start with easy-to-beat neg log probs (= low probs)
        viterbi_scores = dict(zip(states, [neg_infinity]*len(states)))
        # viterbi states (states to have come from) for this t
        back_states = dict(zip(states, ['init']*len(states)))
        if t == 0: ## start state is the state you have to be in at t=0
            viterbi_scores[s0] = 0.0 ## log prob = 0 implies prob = 1
            back_states[s0] = 'init' ## placeholder for debugging
        trellis.append(viterbi_scores)
        back.append(back_states)
    ############################################################################
    # The main body of the viterbi algorithm:
    # Fill in trellis with Viterbi values, back with backpointers
    ############################################################################
    for t in xrange(1, T+1):
        o = Stringseq[t-1] # o is the current observation.
        # print 't: %s' % t
        # print 'o: %s' % o
        try:
            emission_probs = B[o]
        except KeyError:
            print 'Illegal input: %s' % o
            return (trellis, back, [], '')
        # Fill in next column; using log probs so add to get score
        for s_to in states:
            for s_from in states:
                # In s_to at t, coming from s_from at t-1
                if predict:
                    e_state = s_from
                else:
                    e_state = s_to
                try:
                    # print_info(s_from, s_to, A, emission_probs, trellis)
                    score = trellis[t-1][s_from] + A[s_from][s_to] + emission_probs[e_state]
                except KeyError:
                    print 'Key error %s %s %s %s' % (s_from, s_to, o, t)
                    return (trellis, back, [], '')
                # print '   score: %s' % score
                # print '   trellis[t][s_to]: %s' % trellis[t][s_to]
                # print '   trellis[t-1][s_from]: %s' % trellis[t-1][s_from]
                if score > trellis[t][s_to]:
                    trellis[t][s_to] = score
                    back[t][s_to] = s_from

    ############################################################################
    # End of main body of the viterbi algorithm
    ############################################################################
    # Find best state for final piece of input at t=T
    best = s0 # initial value: arbitrary
    for s in states:
        if trellis[T][s] > trellis[T][best]:
            best = s
    path = [best]
    # nice_path = [nice_names[best]] ## Use for debugging and display
    nice_path = [str(best)] # Not all state names are strings. Make sure.
    for t in xrange(T, 0, -1): # count backwards (T ... 1)
        best = back[t][best]
        path[0:0] = [best] # Python idiom for "push"
        nice_path[0:0] = [str(best), # nice_names[best],
                          '--%s-->' % (Stringseq[t-1],)]
    # Make a string of the elements of nice_path separated by ' ' (space)
    nice_path_string = ' '.join(nice_path)
    return (trellis, back, path, nice_path_string)
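
## A minimal usage sketch (illustrative, not executed here): calling
## decode directly with Machine 3, defined below, on the '313' ice
## cream sequence (decode_line does exactly this for the string '313'):
##
##   trellis3, back3, path3, nice3 = decode(['3','1','3'], A3, B3, states3, start3)
##   # path3 is a list of T+1 states beginning with the start state 'start0';
##   # nice3 reads like: start0 --3--> <state> --1--> <state> --3--> <state>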


###########################################################
### Utilities
###########################################################

def make_transition_dict(bigram_prob_dict, pron_prob_dict, states):
218 """ 219 Turn dictionaries representing bigrams probs and 220 pronunciation models into C{transition_dict}, 221 an HMM transition function such that:: 222 223 transition_dict[(phone_from,word_from)][(phone_to,word_to)] = prob 224 225 This dictionary represents the kind of global phone to 226 phone transition probs shown in U{J&M, Figure 7.08<http://www-rohan.sdsu.edu/~gawron/compling/chap7/fig07.08.pdf>}, including cross-word transitions. 227 228 This transition probability across words is estimated as:: 229 230 Pr(end | word_from,phone_from) * 231 Pr(phone_to | start, word_to) * 232 Pr(word_to | word_from) 233 234 or equivalently:: 235 236 pron_prob_dict[word_from][phone_from][end] * 237 pron_prob_dict[word_to][start][phone_to] * 238 bigram_prob_dict[word_from][word_to] 239 240 Thus the resulting HMM transition table has no start word and 241 end word states. 242 243 @param bigram_prob_dict: A dictionary of dictionaries representing 244 bigram probs. 245 @param pron_prob_dict: A dictionary of dictionaries representing 246 with word keys representing bigram pronunication 247 probs. 248 @param states: the list of states for the current HMM. 249 @rtype: dictionary of dictionaries 250 """ 251 transition_dict = {} 252 for (phone_from, word_from) in states: 253 to_dict = transition_dict.setdefault((phone_from,word_from),{}) 254 # print (phone_from, word_from) 255 for (phone_to, word_to) in states: 256 # print ' ',(phone_to, word_to) 257 if word_from == word_to: 258 if phone_to == phone_from: 259 prob = 0.0 260 else: 261 try: 262 prob = float(pron_prob_dict[word_from][phone_from][phone_to]) 263 except KeyError: 264 prob = 0.0 265 else: 266 try: 267 if (phone_from,word_from) == ('start','start'): 268 ending_prob = 1.0 269 else: 270 ending_prob = float(pron_prob_dict[word_from][phone_from]['end']) 271 starting_prob = float(pron_prob_dict[word_to]['start'][phone_to]) 272 bigram_prob = float(bigram_prob_dict[word_from][word_to]) 273 prob = ending_prob * starting_prob * bigram_prob 274 # print ending_prob, starting_prob, bigram_prob, prob 275 except KeyError: 276 prob = 0.0 277 to_dict[(phone_to,word_to)] = prob 278 return transition_dict

##########################################################
## Introducing log probs
##########################################################

import math

def switch_transition_function_to_log_probs(A, states):
    # switch A to log probs
    for start in states:
        for end in states:
            if A[start][end] > 0:
                A[start][end] = math.log(A[start][end], 2) # use log base 2
            else:
                A[start][end] = neg_infinity

def switch_emission_function_to_log_probs(B, states):
    # switch B to log probs
    for word in B.keys():
        for state in states:
            if B[word][state] > 0:
                B[word][state] = math.log(B[word][state], 2)
            else:
                B[word][state] = neg_infinity
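
## Why log probs: the Viterbi recurrence multiplies many probs < 1,
## which underflows quickly on long inputs; in log space products
## become sums.  A small illustration (log base 2, as above):
##
##   math.log(0.5, 2)   # -1.0
##   math.log(0.25, 2)  # -2.0 == log2(0.5) + log2(0.5)
##
## Prob 0.0 is mapped to neg_infinity, so an impossible transition or
## emission can never win the max in decode.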

## NOTE: several display helpers called below (print_viterbi_table,
## print_info) are collapsed in this source listing.

######################################################################
######################################################################
###  M a c h i n e   O n e  (Coin tossing)
######################################################################
######################################################################

start1 = 's0' # start state name
# States
states1 = ['s0', 'H', 'T']

# transition probs A[FromState][ToState]
A1 = {'s0': {'s0': 0.0,
             'H': 0.5,
             'T': 0.5},
      'H': {'s0': 0.0,
            'H': 0.5,
            'T': 0.5},
      'T': {'s0': 0.0,
            'H': 0.5,
            'T': 0.5}}
# Emission probs B[Emission][State]
B1 = {'h': {'s0': 0.5,
            'H': 1.0,
            'T': 0.0},
      't': {'s0': 0.5,
            'H': 0.0,
            'T': 1.0}
      }


## Now switch A to log_probs
switch_transition_function_to_log_probs(A1, states1)

## Now switch B to log_probs
switch_emission_function_to_log_probs(B1, states1)

######################################################################
######################################################################
###  M a c h i n e   T w o  (Word decoder)
######################################################################
######################################################################
from viterbi.bigram_probs import *
from viterbi.pron_probs import *

start2 = ('start', 'start') # start state name

# States for word-decoding. For now, to simplify, we assume (falsely) that
# no phone ever occurs twice in a word. For our tiny
# vocab, this is true. Thus states are uniquely identified
# by a pair of a phone and a word.
states2 = [('end','need'), ('d','need'), ('iy','need'), ('n','need'), ('start','need'),  # need
           ('end','the'), ('ax','the'), ('iy','the'), ('n','the'), ('dh','the'), ('start','the'),  # the
           ('end','on'), ('n','on'), ('aa','on'), ('start','on'),  # on
           ('end','I'), ('ay','I'), ('aa','I'), ('start','I'),  # I
           ('start','start')]
phones = set([])
for state in states2:
    if state[0] not in ['start', 'end']:
        phones.add(state[0])
B2 = {}

# Emission probs: B2
for phone in phones:
    phonedic = B2.setdefault(phone, {})
    for state in states2:
        if phone == state[0]:
            phonedic[state] = 1.0
        else:
            phonedic[state] = 0.0

A2 = make_transition_dict(bigram_probs_dict, pron_probs_dict, states2)

## Now switch A to log_probs
switch_transition_function_to_log_probs(A2, states2)

## Now switch B to log_probs
switch_emission_function_to_log_probs(B2, states2)

## Filter out beginning and end word states, now snipped out of
## A2.
states2 = [state for state in states2 if
           state[0] != 'end' and state[0] != 'start' and
           state[1] != 'end' and state[1] != 'start']
## But still need start state
states2.append(start2)

######################################################################
######################################################################
###  M a c h i n e   T h r e e  (Weather/Ice Cream)
######################################################################
######################################################################


start3 = 'start0' # start state name
# States
states3 = ['start0', 'HOT', 'COLD']

# transition probs A[FromState][ToState]
A3 = {'start0': {'start0': 0.0,
                 'HOT': 0.8,
                 'COLD': 0.2},
      'HOT': {'start0': 0.0,
              'HOT': 0.7,
              'COLD': 0.3},
      'COLD': {'start0': 0.0,
               'HOT': 0.4,
               'COLD': 0.6}}
# Emission probs B[Emission][State]
B3 = {'1': {'start0': 0.0,
            'HOT': 0.2,
            'COLD': 0.5},
      '2': {'start0': 0.0,
            'HOT': 0.4,
            'COLD': 0.4},
      '3': {'start0': 0.0,
            'HOT': 0.4,
            'COLD': 0.1},
      }

## Now switch A to log_probs
switch_transition_function_to_log_probs(A3, states3)

## Now switch B to log_probs
switch_emission_function_to_log_probs(B3, states3)


###############################################################
###  W r a p p e r   C o d e
###############################################################

Stringseq = ['aa', 'n', 'iy', 'dh', 'ax']

# decode_line('httt', A1, B1, states1, start1, True)
# decode_line(Stringseq, A2, B2, states2, start2)
# print_viterbi_table(X[0], Stringseq, states2, start2, 9)
# decode_line('313', A3, B3, states3, start3)

def decode_line(line, A, B, states, start, predict=False):
    if type(line) is list:
        splitline = line
    else:
        splitline = list(line.rstrip())
    (trellis, back, path, nice_path_string) = decode(splitline, A, B, states, start, predict)
    print_viterbi_table(trellis, splitline, states, start)
    print 'nice path: %s' % nice_path_string
    return (trellis, back, path, nice_path_string)