Package PyDSTool :: Package Toolbox :: Module PRCtools
[hide private]
[frames] | [no frames]

Source Code for Module PyDSTool.Toolbox.PRCtools

  1  """Toolbox for phase response curves measured by finite perturbations 
  2  """ 
  3   
  4  import numpy as np 
  5  from PyDSTool import Pointset, Model, embed 
  6  from PyDSTool.Trajectory import pointset_to_traj 
  7   
  8  from PyDSTool.matplotlib_import import * 
  9   
 10  #__all__ = [] 
 11   
 12  # ------------------------------------------------- 
 13   
def one_period_traj(model, ev_name, ev_t_tol, ev_norm_tol, T_est,
                    verbose=False, initial_settle=6, restore_old_ics=False,
                    use_quadratic_interp=False):
    """
    Utility to extract a single period of a limit cycle of the model using forward
    integration, up to a tolerance given in terms of both the period and the norm of the
    vector of variables in the limit cycle at the period endpoints.

    Requires a non-terminal event in the model that is detected exactly once per period.
    Assumes model initial conditions are already in domain of attraction for limit cycle.

    T_est is an initial estimate of period.
    use_quadratic_interp option (default False) indicates whether to make the returned
    trajectory interpolated more accurately using quadratic functions rather than linear ones.
    This option takes a lot longer to complete!

    The model argument can be an instance of a Generator class or Model class.

    Each attempt integrates over [0, T_est*(settle + 0.2)]. It then compares the
    spacing of the last `ev_name` events, and the states at those events,
    looking back over up to `attempt number` event intervals. At most 7 attempts
    are made. Each retry restarts from the state at the last event of the
    previous attempt, with a longer settling time.

    Returns (ref_traj, ref_pts, T_final):
        ref_traj -- the one-period trajectory, named 'one_period', with an
                    `ev_name` event inserted at t = 0;
        ref_pts  -- ref_traj.sample();
        T_final  -- the period length used.

    Raises RuntimeError if an attempt finds fewer than three `ev_name` events,
    or if no attempt satisfies both tolerances.

    If use_quadratic_interp is set, the previous poly_interp setting is
    restored before returning or raising. If restore_old_ics is True, the
    previous initial conditions are likewise restored.
    """
    if not isinstance(model, Model.Model):
        # temporarily embed into a model object
        model = embed(model)
    if use_quadratic_interp:
        old_interp_setting = model.query('algparams')['poly_interp']
        model.set(algparams={'poly_interp': True})
    trajname = '_test_period_'
    old_ics = model.query('ics')
    max_tries = 8   # attempts are numbered 1 .. max_tries-1
    try:
        settle = initial_settle
        tries = 1
        success = False
        while not success and tries < max_tries:
            model.compute(trajname=trajname, tdata=[0, T_est*(settle+0.2)],
                          force=True)
            evts = model.getTrajEventTimes(trajname, ev_name)
            if len(evts) <= 2:
                raise RuntimeError("Not enough events found")
            ref_ic = model(trajname, evts[-1])
            # Candidate k (0-based) spans k+1 event intervals. Candidates that
            # cannot be evaluated keep the large sentinel values, so they never
            # win the argmin selection below.
            t_check = 10000*np.ones((tries,), float)
            norm_check = 10000*np.ones((tries,), float)
            T = np.zeros((tries,), float)
            if verbose:
                print("\n Tries:  %d \n" % tries)
            # Each candidate compares two consecutive spans, so at most
            # len(evts)-2 intervals can be looked back over.
            for lookback in range(1, min(tries, len(evts) - 2) + 1):
                last_span = evts[-1] - evts[-1 - lookback]
                prev_span = evts[-2] - evts[-2 - lookback]
                prev_val = model(trajname, evts[-1 - lookback])
                t_check[lookback-1] = abs(last_span - prev_span)
                norm_check[lookback-1] = np.linalg.norm(ref_ic - prev_val)
                T[lookback-1] = last_span
            # the single-interval span becomes the estimate for the next attempt
            T_est = T[0]
            t_ix = np.argmin(t_check)
            n_ix = np.argmin(norm_check)
            ix1 = min(t_ix, n_ix)
            ix2 = max(t_ix, n_ix)
            if verbose:
                print("%s %s %s" % (t_check, norm_check, T))
                print("%s %s" % (ix1, ix2))
            # prefer the shorter lookback when both criteria pass for it
            if t_check[ix1] < ev_t_tol and norm_check[ix1] < ev_norm_tol:
                success = True
                T_final = T[ix1]
            elif ix1 != ix2 and t_check[ix2] < ev_t_tol and \
                    norm_check[ix2] < ev_norm_tol:
                success = True
                T_final = T[ix2]
            else:
                tries += 1
                settle = tries*2
                model.set(ics=ref_ic)
        if not success:
            print("norm check was %s" % (norm_check,))
            print("t check was %s" % (t_check,))
            raise RuntimeError("Failure to converge after %d attempts"
                               % (tries - 1))
        model.set(ics=ref_ic, tdata=[0, T_final])
        model.compute(trajname='one_period', force=True)
        ref_traj = model['one_period']
        # insert the ON event at beginning of traj
        ref_traj.events[ev_name] = Pointset(indepvararray=[0],
                                            coordarray=np.array([ref_ic.coordarray]).T,
                                            coordnames=ref_ic.coordnames)
        ref_pts = ref_traj.sample()
        return ref_traj, ref_pts, T_final
    finally:
        # restore settings on success and on every failure path alike
        if restore_old_ics:
            model.set(ics=old_ics)
        if use_quadratic_interp:
            model.set(algparams={'poly_interp': old_interp_setting})
109 -def _default_pert(model, ic, pertcoord, pertsize, t0):
110 ic[pertcoord] += pertsize 111 return ic
112 113
def finitePRC(model, ref_traj_period, evname, pertcoord, pertsize=0.05,
              settle=5, verbose=False, skip=1, do_pert=_default_pert,
              keep_trajs=False, stop_at_t=np.inf, force_T=np.nan):
    """Return a Pointset with dependent variable 'D_phase', measured from 0 to 1,
    where D_phase > 0 is an advance.

    Pass a Generator or Model instance for model.
    Pass a Trajectory or Pointset for the ref_traj_period argument.
    Pass the event name in the model that indicates the periodicity.
    Use skip > 1 to sub-sample the points computed along the trajectory at
    the skip rate. The final point (t = T) is always kept.
    Use a do_pert function to do any non-standard perturbation, e.g. if there
    are domain boundary conditions that need special treatment. This function
    is called here as do_pert(model, ic, pertcoord, pertsize, t0). It must
    return the new point ic (not just ic[pertcoord]).
    Use settle=0 to perform no forward integration before the time window in
    which the perturbation will be applied, or a fraction < 1 to ensure an
    integration past the event point (e.g. for non-cycles). Each perturbed run
    is integrated over [0, (settle+1)*T + t0].
    Use stop_at_t to calculate a partial PRC, from perturbation time 0 to this
    value.
    Use force_T to force the period to be whatever value you like.
    With keep_trajs=True, every perturbed trajectory is kept in the model
    (named 'pert_<i>'), and the model is attached to the result as PRC._model.
    Otherwise PRC._model is None.

    Raises RuntimeError if a perturbed run produces fewer `evname` events
    than the event index fixed by the first run.

    Note: Depending on your model, there may be regions of the PRC that are
    offset by a constant amount to the rest of the PRC. This is a "wart" that
    needs improvement. As a partial workaround, each value is shifted by a
    multiple of 0.5 so that it lies closest to the previous value. The first
    value is shifted by a whole cycle so that it lies closest to zero.
    """
    if not isinstance(model, Model.Model):
        # temporarily embed into a model object
        model = embed(model)
    if keep_trajs:
        print("Note: model object will be stored in PRC attribute _model")
    # Only the Trajectory-specific lookups are guarded, so unrelated
    # AttributeErrors are not mistaken for "a Pointset was passed".
    try:
        all_pts = ref_traj_period.sample()
        t_span = ref_traj_period.indepdomain[1] - ref_traj_period.indepdomain[0]
    except AttributeError:
        # already passed points
        all_pts = ref_traj_period
        t_span = all_pts.indepvararray[-1] - all_pts.indepvararray[0]
    ref_pts = all_pts[::skip]
    if ref_pts[-1] != all_pts[-1]:
        # ensure last point at t=T is present
        ref_pts.append(all_pts[[-1]])
    T = t_span if np.isnan(force_T) else force_T
    ref_ts = ref_pts.indepvararray
    n_ts = len(ref_ts)
    PRCvals = []
    if verbose:
        print("Period T = %s" % (T,))
    for i, t0 in enumerate(ref_ts):
        if t0 > stop_at_t:
            break
        ic = do_pert(model, ref_pts[i], pertcoord, pertsize, t0)
        t_end = (settle+1)*T + t0
        if verbose:
            print("%d of %d : t0 = %s of %s  t_end %s" % (i, n_ts, t0, T, t_end))
            print("    %s" % (ic,))
        model.set(ics=ic.copy(), tdata=[0, t_end])
        trajname = 'pert_%i' % i if keep_trajs else 'pert'
        model.compute(trajname=trajname, force=True)
        evts = model.getTrajEventTimes(trajname, evname)
        if i == 0:
            # make sure to always use the same event number
            # (second-to-last event of the first run)
            evnum = max(0, len(evts)-2)
        if len(evts) <= evnum:
            raise RuntimeError("Perturbation %d (t0 = %s) produced %d '%s' "
                               "events; at least %d needed"
                               % (i, t0, len(evts), evname, evnum + 1))
        if verbose:
            print("  Last event: %s" % (evts[-1],))
        val = (T - np.mod(evts[evnum]+t0, T))/T
        if i > 0:
            # Assume continuity of the PRC: hack-fix the modulo wart by
            # choosing the shifted value closest to the previous one.
            test_vals = np.array([val-2, val-1, val-0.5, val, val+0.5, val+1, val+2])
            val = test_vals[np.argmin(abs(PRCvals[-1] - test_vals))]
            if verbose and abs(PRCvals[-1] - val) > 0.05:
                print("\nCorrected value %d %s %s" % (i, PRCvals, val))
        else:
            # Adjust the first value to lie closest to zero, given that the
            # minimum is assumed to be at the beginning of the run.
            test_vals = np.array([val-1, val, val+1])
            val = test_vals[np.argmin(abs(test_vals))]
        PRCvals.append(val)
    PRC = Pointset(coordarray=[PRCvals], coordnames=['D_phase'],
                   indepvararray=ref_ts[:len(PRCvals)], indepvarname='t')
    PRC._model = model if keep_trajs else None
    return PRC
def compare_pert(model, ref_traj_period, evname, pertcoord, pertsize, t0, settle=5,
                 do_pert=_default_pert, fignum=None):
    """Show perturbed and un-perturbed trajectories starting at t0, with given
    perturbation function do_pert.

    The point of ref_traj_period at time t0 is perturbed with
    do_pert(model, ic, pertcoord, pertsize, t0), the same call convention
    finitePRC uses. ref_traj_period must be a Trajectory. The perturbed point
    is integrated over [0, settle*T + t0], where T is the length of the
    reference time domain. The PRC value derived from the last `evname` event
    is printed.

    If fignum is given, the pertcoord component is plotted in that figure:
    - the reference cycle in green, repeated every T until it spans the
      perturbed run;
    - the perturbed run in red, shifted by t0.

    Returns the sampled perturbed trajectory, with times shifted by t0.

    Raises:
    - TypeError if ref_traj_period has no sample() method;
    - ValueError if t0 is not strictly inside the reference time range;
    - RuntimeError if the perturbed run contains no `evname` events.
    """
    if not isinstance(model, Model.Model):
        # temporarily embed into a model object
        model = embed(model)
    try:
        all_pts = ref_traj_period.sample()
    except AttributeError:
        raise TypeError("Must pass a reference trajectory object")
    T = ref_traj_period.indepdomain[1] - ref_traj_period.indepdomain[0]
    ref_ts = all_pts.indepvararray
    if not ref_ts[0] < t0 < ref_ts[-1]:
        raise ValueError("t0 out of range")
    ic = do_pert(model, ref_traj_period(t0), pertcoord, pertsize, t0)
    t_end = settle*T + t0
    model.set(ics=ic, tdata=[0, t_end])
    model.compute(trajname='compare_pert', force=True)
    evts = model.getTrajEventTimes('compare_pert', evname)
    if len(evts) == 0:
        raise RuntimeError("No '%s' events found in perturbed run" % evname)
    PRC_val = -np.mod(evts[-1]+t0, T)/T
    if abs(PRC_val) > 0.5:
        PRC_val += 1
    print("t0 = %.6f, PRC value = %f" % (t0, PRC_val))

    pert_pts = model.sample('compare_pert')
    pert_pts.indepvararray += t0
    if fignum is not None:
        figure(fignum)
        ref_x = all_pts[pertcoord]
        plot(ref_ts, ref_x, 'g')
        # plot additional periodic cycles to compare with pert traj,
        # without mutating the sampled reference points
        shift = 0.0
        while ref_ts[-1] + shift < t_end:
            shift += T
            plot(ref_ts + shift, ref_x, 'g')
        plot(pert_pts.indepvararray, pert_pts[pertcoord], 'r')
    return pert_pts
def fix_PRC(PRC, tol=0.01):
    """Experimental clean-up of 'D_phase' samples that have jumped by half a cycle.

    Any value above 0.5 - tol is shifted down by 0.5. Any value below
    tol - 0.5 is shifted up by 0.5. All other values are left alone. Only
    half-cycle jumps are corrected here, not whole-cycle ones.

    Corrections are written into PRC.coordarray[0]. That row is assumed to be
    'D_phase', as in the single-coordinate Pointset that finitePRC builds.
    The same (mutated) PRC object is returned.
    """
    phases = np.asarray(PRC['D_phase'])
    too_high = phases > 0.5 - tol
    too_low = ~too_high & (phases < tol - 0.5)
    # Gather every corrected value before writing any of them back.
    updates = [(ix, phases[ix] - 0.5) for ix in np.flatnonzero(too_high)]
    updates += [(ix, phases[ix] + 0.5) for ix in np.flatnonzero(too_low)]
    for ix, phase in updates:
        PRC.coordarray[0][ix] = phase
    return PRC