Package PyDSTool :: Package Generator :: Module Vode_ODEsystem
[hide private]
[frames] | no frames]

Source Code for Module PyDSTool.Generator.Vode_ODEsystem

  1  """VODE integrator for ODE systems, imported from a mild modification of
 
  2  the scipy-wrapped VODE Fortran solver.
 
  3  """ 
  4  from __future__ import division 
  5  
 
  6  from allimports import * 
  7  from PyDSTool.Generator import ODEsystem as ODEsystem 
  8  from baseclasses import Generator, theGenSpecHelper, _pollInputs 
  9  from PyDSTool.utils import * 
 10  from PyDSTool.common import * 
 11  from PyDSTool.scipy_ode import ode 
 12  
 
 13  # Other imports
 
 14  from numpy import Inf, NaN, isfinite, sometrue, alltrue, sign, all, any, \
 
 15       array, zeros, less_equal, transpose, concatenate, asarray, linspace 
 16  try: 
 17      from numpy import unique 
 18  except ImportError: 
 19      # older version of numpy
 
 20      from numpy import unique1d as unique 
 21  import math, random 
 22  from copy import copy, deepcopy 
 23  import os, platform, shutil, sys, gc 
 24  
 
 25  # PSYCO FUNCTIONS DON'T WORK AS VODE CALLBACK FUNCTIONS!
 
 26  #try:
 
 27  #    # use psyco JIT byte-compiler optimization, if available
 
 28  #    import psyco
 
 29  #    HAVE_PSYCO = True
 
 30  #except ImportError:
 
 31  #    HAVE_PSYCO = False
 
 32  HAVE_PSYCO = False 
 33  
 
 34  
 
35 -class Vode_ODEsystem(ODEsystem):
36 """Wrapper for VODE, from SciPy. 37 38 Uses Python target language only for functional specifications.""" 39 40 _paraminfo = {'init_step': 'Fixed step size for time mesh.', 41 'strictdt': 'Boolean determining whether to evenly space time mesh (default=False), or to use exactly dt spacing.', 42 'stiff': 'Boolean to activate the BDF method, otherwise Adams method used. Default False.', 43 'use_special': "Switch for using special times", 44 'specialtimes': "List of special times to use during integration",} 45
46 - def __init__(self, kw):
47 ODEsystem.__init__(self, kw) 48 49 self.diagnostics._errorcodes = \ 50 {0: 'Unrecognized error code returned (see stderr output)', 51 -1: 'Excess work done on this call. (Perhaps wrong method MF.)', 52 -2: 'Excess accuracy requested. (Tolerances too small.)', 53 -3: 'Illegal input detected. (See printed message.)', 54 -4: 'Repeated error test failures. (Check all input.)', 55 -5: 'Repeated convergence failures. (Perhaps bad' 56 ' Jacobian supplied or wrong choice of method MF or tolerances.)', 57 -6: 'Error weight became zero during problem. (Solution' 58 ' component i vanished, and ATOL or ATOL(i) = 0.)' 59 } 60 self.diagnostics.outputStatsInfo = {'errorStatus': 'Error status on completion.'} 61 # note: VODE only supports array atol, not rtol. 62 algparams_def = {'poly_interp': False, 63 'rtol': 1e-9, 64 'atol': [1e-12 for dimix in xrange(self.dimension)], 65 'stiff': False, 66 'max_step': 0.0, 67 'min_step': 0.0, 68 'init_step': 0.01, 69 'max_pts': 1000000, 70 'strictdt': False, 71 'use_special': False, 72 'specialtimes': [] 73 } 74 for k, v in algparams_def.iteritems(): 75 if k not in self.algparams: 76 self.algparams[k] = v
77 78
79 - def addMethods(self):
80 # override to add _solver function 81 ODEsystem.addMethods(self, usePsyco=False) 82 if self.haveJacobian(): 83 self._solver = ode(getattr(self,self.funcspec.spec[1]), 84 getattr(self,self.funcspec.auxfns["Jacobian"][1])) 85 self._funcreg['_solver'] = ('self', 86 'ode(getattr(self,self.funcspec.spec[1]),' \ 87 + 'getattr(self,self.funcspec.auxfns["Jacobian"][1]))') 88 else: 89 self._solver = ode(getattr(self,self.funcspec.spec[1])) 90 self._funcreg['_solver'] = ('self', 'ode(getattr(self,' \ 91 + 'self.funcspec.spec[1]))')
92 93
94 - def _debug_snapshot(self, solver, dt, inputlist):
95 ivals = [i(solver.t) for i in inputlist] 96 s = "\n***************\nNew t, x, inputs: " + " ".join([str(s) for s in (solver.t,solver.y,ivals)]) 97 s += "\ndt="+str(dt)+" f_params="+str(solver.f_params)+" dx/dt=" 98 s += str(solver.f(solver.t, solver.y, sortedDictValues(self.pars)+ivals)) 99 if solver.t > 7: 100 s += "\nfirst, max, min steps =" + str([solver.first_step, solver.max_step, solver.min_step]) 101 return s
102
103 - def compute(self, trajname, dirn='f', ics=None):
104 continue_integ = ODEsystem.prepDirection(self, dirn) 105 if self._dircode == -1: 106 raise NotImplementedError('Backwards integration is not implemented') 107 if ics is not None: 108 self.set(ics=ics) 109 self.validateICs() 110 self.diagnostics.clearWarnings() 111 self.diagnostics.clearErrors() 112 pnames = sortedDictKeys(self.pars) 113 xnames = self._var_ixmap # ensures correct order 114 # Check i.c.'s are well defined (finite) 115 self.checkInitialConditions() 116 if self.algparams['stiff']: 117 methstr = 'bdf' 118 methcode = 2 119 else: 120 methstr = 'adams' 121 methcode = 1 122 haveJac = int(self.haveJacobian()) 123 if isinstance(self.algparams['atol'], list): 124 if len(self.algparams['atol']) != self.dimension: 125 raise ValueError('atol list must have same length as phase ' 126 'dimension') 127 else: 128 atol = self.algparams['atol'] 129 self.algparams['atol'] = [atol for dimix in xrange(self.dimension)] 130 indepdom0, indepdom1 = self.indepvariable.depdomain.get() 131 if continue_integ: 132 if indepdom0 > self._solver.t: 133 print "Previous end time is %f"%self._solver.t 134 raise ValueError("Start time not correctly updated for " 135 "continuing orbit") 136 x0 = self._solver.y 137 indepdom0 = self._solver.t 138 self.indepvariable.depdomain.set((indepdom0,indepdom1)) 139 else: 140 x0 = sortedDictValues(self.initialconditions, 141 self.funcspec.vars) 142 t0 = indepdom0 143 if self.algparams['use_special']: 144 tmesh_special = list(self.algparams['specialtimes']) 145 if continue_integ: 146 if self._solver.t not in tmesh_special: 147 raise ValueError("Invalid time to continue integration:" 148 "it is not in 'special times'") 149 tmesh_special = tmesh_special[tmesh_special.index(self._solver.t):] 150 try: 151 dt = min([tmesh_special[1]-t0, self.algparams['init_step']]) 152 except: 153 raise ValueError("Invalid list of special times") 154 if not isincreasing(tmesh_special): 155 raise ValueError("special times must be given in increasing " 156 "order") 157 if 
self.indepvariable.depdomain.contains(t0) is notcontained or \ 158 self.indepvariable.depdomain.contains(tmesh_special[-1]) is notcontained: 159 raise PyDSTool_BoundsError("special times were outside of independent " 160 "variable bounds") 161 else: 162 tmesh_special = [] 163 dt = self.algparams['init_step'] 164 # speed up repeated access to solver by making a temp name for it 165 solver = self._solver 166 if solver._integrator is None: 167 # Banded Jacobians not yet supported 168 # 169 # start a new integrator, because method may have been 170 # switched 171 solver.set_integrator('vode', method=methstr, 172 rtol=self.algparams['rtol'], 173 atol=self.algparams['atol'], 174 nsteps=self.algparams['max_pts'], 175 max_step=self.algparams['max_step'], 176 min_step=self.algparams['min_step'], 177 first_step=dt/2., 178 with_jacobian=haveJac) 179 else: 180 solver.with_jacobian = haveJac 181 # self.mu = lband 182 # self.ml = uband 183 solver.rtol = self.algparams['rtol'] 184 solver.atol = self.algparams['atol'] 185 solver.method = methcode 186 # self.order = order 187 solver.nsteps = self.algparams['max_pts'] 188 solver.max_step = self.algparams['max_step'] 189 solver.min_step = self.algparams['min_step'] 190 solver.first_step = dt/2. 
191 solver.set_initial_value(x0, t0) 192 ## Broken code for going backwards (doesn't respect 'continue' option 193 ## either) 194 ## if self._dircode == 1: 195 ## solver.set_initial_value(x0, t0) 196 ## else: 197 ## solver.set_initial_value(x0, indepdom1) 198 # wrap up each dictionary initial value as a singleton list 199 alltData = [t0] 200 allxDataDict = dict(zip(xnames, map(listid, x0))) 201 plist = sortedDictValues(self.pars) 202 extralist = copy(plist) 203 if self.inputs: 204 # inputVarList is a list of Variables 205 inames = sortedDictKeys(self.inputs) 206 listend = self.numpars + len(self.inputs) 207 inputVarList = sortedDictValues(self.inputs) 208 ilist = _pollInputs(inputVarList, alltData[0]+self.globalt0, 209 self.checklevel) 210 else: 211 ilist = [] 212 inames = [] 213 listend = self.numpars 214 inputVarList = [] 215 extralist.extend(ilist) 216 solver.set_f_params(extralist) 217 if haveJac: 218 solver.set_jac_params(extralist) 219 do_poly = self.algparams['poly_interp'] 220 if do_poly: 221 rhsfn = getattr(self, self.funcspec.spec[1]) 222 dx0 = rhsfn(t0, x0, extralist) 223 alldxDataDict = dict(zip(xnames, map(listid, dx0))) 224 auxvarsfn = getattr(self,self.funcspec.auxspec[1]) 225 strict = self.algparams['strictdt'] 226 # Make t mesh if it wasn't given as 'specialtimes' 227 if not all(isfinite(self.indepvariable.depdomain.get())): 228 print "Time domain was: ", self.indepvariable.depdomain.get() 229 raise ValueError("Ensure time domain is finite") 230 if dt == indepdom1 - indepdom0: 231 # single-step integration required 232 # special times will not have been set (unless trivially 233 # they are [indepdom0, indepdom1]) 234 tmesh = [indepdom0, indepdom1] 235 else: 236 notDone = True 237 repeatTol = 10 238 count = 0 239 while notDone and count <= repeatTol: 240 try: 241 tmesh = self.indepvariable.depdomain.sample(dt, 242 strict=strict, 243 avoidendpoints=self.checklevel>2) 244 notDone = False 245 except AssertionError: 246 count += 1 247 dt = dt/3.0 248 if 
count == repeatTol: 249 raise AssertionError("Supplied time step is too large for " 250 "selected time interval") 251 # incorporate tmesh_special, if used, ensuring uniqueness 252 if tmesh_special != []: 253 tmesh.extend(tmesh_special) 254 tmesh = list(unique(tmesh)) 255 tmesh.sort() 256 if len(tmesh)<=2: 257 # safety net, in case too few points in mesh 258 # too few points unless we can add endpoint 259 if tmesh[-1] != indepdom1: 260 # dt too large for tmesh to have more than one point 261 tmesh.append(indepdom1) 262 if not strict: # get actual time step used 263 # don't use [0] in case avoided end points 264 try: 265 dt = tmesh[2]-tmesh[1] 266 except IndexError: 267 # can't avoid end points for such a small mesh 268 dt = tmesh[1]-tmesh[0] 269 #if self.eventstruct.query(['lowlevel']) != []: 270 # raise ValueError("Only high level events can be passed to VODE") 271 eventslist = self.eventstruct.query(['active', 'notvarlinked']) 272 termevents = self.eventstruct.query(['term'], eventslist) 273 # reverse time by reversing mesh doesn't work 274 ## if self._dircode == -1: 275 ## tmesh.reverse() 276 tmesh.pop(0) # get rid of first entry for initial condition 277 xnames = self.funcspec.vars 278 # storage of all auxiliary variable data 279 allaDataDict = {} 280 anames = self.funcspec.auxvars 281 avals = auxvarsfn(t0, x0, extralist) 282 for aix in range(len(anames)): 283 aname = anames[aix] 284 try: 285 allaDataDict[aname] = [avals[aix]] 286 except IndexError: 287 print "\nVODE generator: There was a problem evaluating " \ 288 + "an auxiliary variable" 289 print "Debug info: avals (length", len(avals), ") was ", avals 290 print "Index out of range was ", aix 291 print self.funcspec.auxspec[1] 292 print hasattr(self, self.funcspec.auxspec[1]) 293 print "Args were:", [t0, x0, extralist] 294 raise 295 # Initialize signs of event detection objects at IC 296 self.setEventICs(self.initialconditions, self.globalt0) 297 if self.inputs: 298 parsinps = copy(self.pars) 299 
parsinps.update(dict(zip(inames,ilist))) 300 else: 301 parsinps = self.pars 302 if eventslist != []: 303 dataDict = copy(self.initialconditions) 304 dataDict.update(dict(zip(anames, avals))) 305 dataDict['t'] = t0 306 evsflagged = self.eventstruct.pollHighLevelEvents(None, 307 dataDict, 308 parsinps, 309 eventslist) 310 if len(evsflagged) > 0: 311 raise RuntimeError("Some events flagged at initial condition") 312 if continue_integ: 313 # revert to prevprevsign, since prevsign changed after call 314 self.eventstruct.resetHighLevelEvents(t0, eventslist, 'prev') 315 elif self._for_hybrid_DS: 316 # self._for_hybrid_DS is set internally by HybridModel class 317 # to ensure not to reset events, because they may be about to 318 # flag on first step if previous hybrid state was the same 319 # generator and, for example, two variables are synchronizing 320 # so that their events get very close together. 321 # Just reset the starttimes of these events 322 for evname, ev in eventslist: 323 ev.starttime = t0 324 else: 325 # default state is a one-off call to this generator 326 self.eventstruct.resetHighLevelEvents(t0, eventslist, None) 327 self.eventstruct.validateEvents(self.funcspec.vars + \ 328 self.funcspec.auxvars + \ 329 self.funcspec.inputs + \ 330 ['t'], eventslist) 331 # temp storage of first time at which terminal events found 332 # (this is used for keeping the correct end point of new mesh) 333 first_found_t = None 334 # list of precise non-terminal events to be resolved after integration 335 nontermprecevs = [] 336 evnames = [ev[0] for ev in eventslist] 337 lastevtime = {}.fromkeys(evnames, None) 338 # initialize new event info dictionaries 339 Evtimes = {} 340 Evpoints = {} 341 if continue_integ: 342 for evname in evnames: 343 try: 344 # these are in global time, so convert to local time 345 lastevtime[evname] = self.eventstruct.Evtimes[evname][-1] \ 346 - self.globalt0 347 except (IndexError, KeyError): 348 # IndexError: Evtimes[evname] was None 349 # KeyError: 
Evtimes does not have key evname 350 pass 351 for evname in evnames: 352 Evtimes[evname] = [] 353 Evpoints[evname] = [] 354 # temp storage for repeatedly used object attributes (for lookup efficiency) 355 depdomains = dict(zip(range(self.dimension), 356 [self.variables[xn].depdomain for xn in xnames])) 357 # Main integration loop 358 num_points = 0 359 breakwhile = False 360 while not breakwhile: 361 try: 362 new_t = tmesh.pop(0) # this destroys tmesh for future use 363 except IndexError: 364 # empty 365 break 366 last_t = solver.t # a record of previous time for use by event detector 367 try: 368 errcode = solver.integrate(new_t) 369 except: 370 print "Error calling right hand side function:" 371 self.showSpec() 372 print "Numerical traceback information (current state, " \ 373 + "parameters, etc.)" 374 print "in generator dictionary 'traceback'" 375 self.traceback = {'vars': dict(zip(xnames,solver.y)), 376 'pars': dict(zip(pnames,plist)), 377 'inputs': dict(zip(inames,ilist)), 378 self.indepvariable.name: new_t} 379 raise 380 avals = auxvarsfn(new_t, solver.y, extralist) 381 # Uncomment the following assertion for debugging 382 # assert all([isfinite(a) for a in avals]), \ 383 # "Some auxiliary variable values not finite" 384 if eventslist != []: 385 dataDict = dict(zip(xnames,solver.y)) 386 dataDict.update(dict(zip(anames,avals))) 387 dataDict['t'] = new_t 388 if self.inputs: 389 parsinps = copy(self.pars) 390 parsinps.update(dict(zip(inames, 391 _pollInputs(inputVarList, new_t+self.globalt0, 392 self.checklevel)))) 393 else: 394 parsinps = self.pars 395 evsflagged = self.eventstruct.pollHighLevelEvents(None, 396 dataDict, 397 parsinps, 398 eventslist) 399 ## print new_t, evsflagged 400 ## evsflagged = [ev for ev in evsflagged if solver.t-indepdom0 > ev[1].eventinterval] 401 termevsflagged = filter(lambda e: e in evsflagged, termevents) 402 nontermevsflagged = filter(lambda e: e not in termevsflagged, 403 evsflagged) 404 # register any non-terminating events in 
the warnings 405 # list, unless they are 'precise' in which case flag 406 # them to be resolved after integration completes 407 if len(nontermevsflagged) > 0: 408 evnames = [ev[0] for ev in nontermevsflagged] 409 precEvts = self.eventstruct.query(['precise'], 410 nontermevsflagged) 411 prec_evnames = [e[0] for e in precEvts] 412 # first register non-precise events 413 nonprecEvts = self.eventstruct.query(['notprecise'], 414 nontermevsflagged) 415 nonprec_evnames = [e[0] for e in nonprecEvts] 416 # only record events if they have not been previously 417 # flagged within their event interval 418 if nonprec_evnames != []: 419 temp_names = [] 420 for evname, ev in nonprecEvts: 421 prevevt_time = lastevtime[evname] 422 if prevevt_time is None: 423 ignore_ev = False 424 else: 425 if solver.t-prevevt_time < ev.eventinterval: 426 ignore_ev = True 427 else: 428 ignore_ev = False 429 if not ignore_ev: 430 temp_names.append(evname) 431 lastevtime[evname] = solver.t 432 self.diagnostics.warnings.append((W_NONTERMEVENT, 433 (solver.t, temp_names))) 434 for evname in temp_names: 435 Evtimes[evname].append(solver.t) 436 xv = solver.y 437 av = array(avals) 438 Evpoints[evname].append(concatenate((xv, av))) 439 for evname, ev in precEvts: 440 # only record events if they have not been previously 441 # flagged within their event interval 442 prevevt_time = lastevtime[evname] 443 if prevevt_time is None: 444 ignore_ev = False 445 else: 446 if last_t-prevevt_time < ev.eventinterval: 447 ignore_ev = True 448 else: 449 ignore_ev = False 450 if not ignore_ev: 451 nontermprecevs.append((last_t, solver.t, (evname, ev))) 452 # be conservative as to where the event is, so 453 # that don't miss any events. 
454 lastevtime[evname] = last_t # solver.t-dt 455 ev.reset() #ev.prevsign = None # 456 do_termevs = [] 457 if termevsflagged != []: 458 # only record events if they have not been previously 459 # flagged within their event interval 460 for e in termevsflagged: 461 prevevt_time = lastevtime[e[0]] 462 ## print "Event %s flagged."%e[0] 463 ## print " ... last time was ", prevevt_time 464 ## print " ... event interval = ", e[1].eventinterval 465 ## print " ... t = %f, dt = %f"%(solver.t, dt) 466 if prevevt_time is None: 467 ignore_ev = False 468 else: 469 ## print " ... comparison = %f < %f"%(solver.t-dt-prevevt_time, e[1].eventinterval) 470 if last_t-prevevt_time < e[1].eventinterval: 471 ignore_ev = True 472 ## print "VODE ignore ev" 473 else: 474 ignore_ev = False 475 if not ignore_ev: 476 do_termevs.append(e) 477 if len(do_termevs) > 0: 478 # >= 1 active terminal event flagged at this time point 479 if all([not ev[1].preciseFlag for ev in do_termevs]): 480 # then none of the events specify greater accuracy 481 # register the event in the warnings 482 evnames = [ev[0] for ev in do_termevs] 483 self.diagnostics.warnings.append((W_TERMEVENT, \ 484 (solver.t, evnames))) 485 for evname in evnames: 486 Evtimes[evname].append(solver.t) 487 xv = solver.y 488 av = array(avals) 489 Evpoints[evname].append(concatenate((xv, av))) 490 # break while loop after appending t, x 491 breakwhile = True 492 else: 493 # find which are the 'precise' events that flagged 494 precEvts = self.eventstruct.query(['precise'], 495 do_termevs) 496 # these events have flagged once so eventdelay has 497 # been used. 
now switch it off while finding event 498 # precisely (should be redundant after change to 499 # eventinterval and eventdelay parameters) 500 evnames = [ev[0] for ev in precEvts] 501 if first_found_t is None: 502 ## print "first time round at", solver.t 503 numtries = 0 504 first_found_t = solver.t 505 restore_evts = deepcopy(precEvts) 506 minbisectlimit = min([ev[1].bisectlimit for ev in precEvts]) 507 for ev in precEvts: 508 ev[1].eventdelay = 0. 509 else: 510 numtries += 1 511 ## print "time round: ", numtries 512 if numtries > minbisectlimit: 513 self.diagnostics.warnings.append((W_BISECTLIMIT, 514 (solver.t, evnames))) 515 breakwhile = True 516 517 # get previous time point 518 if len(alltData)>=1: 519 # take one step back -> told, which will 520 # get dt added back to first new meshpoint 521 # (solver.t is the step *after* the event was 522 # detected) 523 told = last_t # solver.t-dt without the loss of precision from subtraction 524 else: 525 raise ValueError("Event %s found too "%evnames[0]+\ 526 "close to local start time: try decreasing " 527 "initial step size (current size is " 528 "%f @ t=%f)"%(dt,solver.t+self.globalt0)) 529 530 # absolute tolerance check on event function values (not t) 531 in_tols = [abs(e[1].fval) < e[1].eventtol for e in precEvts] 532 if all(in_tols): 533 ## print "Registering event:", dt_min, dt 534 # register the event in the warnings 535 self.diagnostics.warnings.append((W_TERMEVENT, 536 (solver.t, evnames))) 537 for evname in evnames: 538 Evtimes[evname].append(solver.t) 539 xv = solver.y 540 av = array(avals) 541 Evpoints[evname].append(concatenate((xv, av))) 542 # Cannot continue -- dt_min no smaller than 543 # previous dt. If this is more than the first time 544 # in this code then have found the event to within 545 # the minimum 'precise' event's eventtol, o/w need 546 # to set eventtol smaller. 
547 # Before exiting event-finding loop, reset prevsign 548 # of flagged events 549 self.eventstruct.resetHighLevelEvents(0, 550 precEvts) 551 breakwhile = True # while loop, but append point first 552 if not breakwhile: 553 dt_new = dt/5.0 554 # calc new tmesh 555 trangewidth = 2*dt #first_found_t - told 556 numpoints = int(math.ceil(trangewidth/dt_new)) 557 # choose slightly smaller dt to fit trange exactly 558 dt = trangewidth/numpoints 559 560 tmesh = [told + i*dt for i in xrange(1, numpoints+1)] 561 # linspace version is *much* slower for numpoints ~ 10 and 100 562 #tmesh = list(told+linspace(dt, numpoints*dt, numpoints)) 563 564 # reset events according to new time mesh, 565 # setting known previous event state to be their 566 # "no event found" state 567 self.eventstruct.resetHighLevelEvents(told, 568 precEvts, 569 state='off') 570 # build new ic with last good values (at t=told) 571 if len(alltData)>1: 572 new_ic = [allxDataDict[xname][-1] \ 573 for xname in xnames] 574 else: 575 new_ic = x0 576 # reset integrator 577 solver.set_initial_value(new_ic, told) 578 extralist[self.numpars:listend] = _pollInputs(inputVarList, 579 told+self.globalt0, self.checklevel) 580 solver.set_f_params(extralist) 581 # continue integrating over new mesh 582 continue # while 583 # after events have had a chance to be detected at a state boundary 584 # we check for any that have not been caught by an event (will be 585 # much less accurately determined) 586 if not breakwhile: 587 # only here if a terminal event hasn't just flagged 588 for xi in xrange(self.dimension): 589 if not self.contains(depdomains[xi], 590 solver.y[xi], 591 self.checklevel): 592 self.diagnostics.warnings.append((W_TERMSTATEBD, 593 (solver.t, 594 xnames[xi],solver.y[xi], 595 depdomains[xi].get()))) 596 breakwhile = True 597 break # for loop 598 if breakwhile: 599 break # while loop 600 alltData.append(solver.t) 601 if do_poly: 602 dxvals = rhsfn(solver.t, solver.y, extralist) 603 for xi, xname in 
enumerate(xnames): 604 allxDataDict[xname].append(solver.y[xi]) 605 alldxDataDict[xname].append(dxvals[xi]) 606 else: 607 for xi, xname in enumerate(xnames): 608 allxDataDict[xname].append(solver.y[xi]) 609 for aix, aname in enumerate(anames): 610 allaDataDict[aname].append(avals[aix]) 611 num_points += 1 612 if not breakwhile: 613 try: 614 extralist[self.numpars:listend] = [f(solver.t+self.globalt0, 615 self.checklevel) \ 616 for f in inputVarList] 617 except ValueError: 618 print 'External input call caused value out of range error:',\ 619 't = ', solver.t 620 for f in inputVarList: 621 if f.diagnostics.hasWarnings(): 622 print 'External input variable %s out of range:'%f.name 623 print ' t = ', repr(f.diagnostics.warnings[-1][0]), ', ', \ 624 f.name, ' = ', repr(f.diagnostics.warnings[-1][1]) 625 raise 626 except AssertionError: 627 print 'External input call caused t out of range error: t = ', \ 628 solver.t 629 raise 630 solver.set_f_params(extralist) 631 breakwhile = not solver.successful() 632 # Check that any terminal events found terminated the code correctly 633 if first_found_t is not None: 634 # ... then terminal events were found. Those that were 'precise' had 635 # their 'eventdelay' attribute temporarily set to 0. It now should 636 # be restored. 
637 for evname1, ev1 in termevents: 638 # restore_evts are copies of the originally flagged 'precise' 639 # events 640 for evname2, ev2 in restore_evts: 641 if evname2 == evname1: 642 ev1.eventdelay = ev2.eventdelay 643 try: 644 if self.diagnostics.warnings[-1][0] not in [W_TERMEVENT, 645 W_TERMSTATEBD]: 646 print "t =", solver.t 647 print "state =", dict(zip(xnames,solver.y)) 648 raise RuntimeError("Event finding code for terminal event " 649 "failed in Generator " + self.name + \ 650 ": try decreasing eventdelay or " 651 "eventinterval below eventtol, or the " 652 "atol and rtol parameters") 653 except IndexError: 654 info(self.diagnostics.outputStats, "Output statistics") 655 print "t =", solver.t 656 print "x =", solver.y 657 raise RuntimeError("Event finding failed in Generator " + \ 658 self.name + ": try decreasing eventdelay " 659 "or eventinterval below eventtol, or the " 660 "atol and rtol parameters") 661 # Package up computed trajectory in Variable variables 662 # Add external inputs warnings to self.diagnostics.warnings, if any 663 for f in inputVarList: 664 for winfo in f.diagnostics.warnings: 665 self.diagnostics.warnings.append((W_NONTERMSTATEBD, 666 (winfo[0], f.name, winfo[1], 667 f.depdomain.get()))) 668 # check for non-unique terminal event 669 termcount = 0 670 for (w,i) in self.diagnostics.warnings: 671 if w == W_TERMEVENT or w == W_TERMSTATEBD: 672 termcount += 1 673 if termcount > 1: 674 self.diagnostics.errors.append((E_NONUNIQUETERM, 675 (alltData[-1], i[1]))) 676 # uncomment the following lines for debugging 677 # assert len(alltData) == len(allxDataDict.values()[0]) \ 678 # == len(allaDataDict.values()[0]), "Output data size mismatch" 679 # for val_list in allaDataDict.values(): 680 # assert all([isfinite(x) for x in val_list]) 681 # Create variables (self.variables contains no actual data) 682 # These versions of the variables are only final if no non-terminal 683 # events need to be inserted. 
684 variables = copyVarDict(self.variables) 685 for x in xnames: 686 if len(alltData) > 1: 687 if do_poly: 688 xvals = array([allxDataDict[x], alldxDataDict[x]]).T 689 interp = PiecewisePolynomial(alltData, xvals, 2) 690 else: 691 xvals = allxDataDict[x] 692 interp = interp1d(alltData, xvals) 693 variables[x] = Variable(interp, 't', x, x) 694 else: 695 print "Error in Generator:", self.name 696 print "t = ", alltData 697 print "x = ", allxDataDict 698 raise PyDSTool_ValueError("Fewer than 2 data points computed") 699 for a in anames: 700 if len(alltData) > 1: 701 variables[a] = Variable(interp1d(alltData, allaDataDict[a]), 702 't', a, a) 703 else: 704 print "Error in Generator:", self.name 705 print "t = ", alltData 706 print "x = ", allxDataDict 707 raise PyDSTool_ValueError("Fewer than 2 data points computed") 708 # Resolve non-terminal 'precise' events that were flagged, using the 709 # variables created. Then, add them to a new version of the variables. 710 ntpe_tdict = {} 711 for (et0,et1,e) in nontermprecevs: 712 lost_evt = False 713 # problem if eventinterval > et1-et0 !! 714 # was: search_dt = max((et1-et0)/5,e[1].eventinterval) 715 search_dt = (et1-et0)/5 716 try: 717 et_precise_list = e[1].searchForEvents(trange=[et0,et1], 718 dt=search_dt, 719 checklevel=self.checklevel, 720 parDict=self.pars, 721 vars=variables, 722 inputs=self.inputs, 723 abseps=self._abseps, 724 eventdelay=False, 725 globalt0=self.globalt0) 726 except (ValueError, PyDSTool_BoundsError): 727 # dt too large for trange, e.g. 
if event tol is smaller than time mesh 728 et_precise_list = [(et0, (et0, et1))] 729 if et_precise_list == []: 730 lost_evt = True 731 for et_precise in et_precise_list: 732 if et_precise[0] is not None: 733 if et_precise[0] in ntpe_tdict: 734 # add event name at this time (that already exists in the dict) 735 ntpe_tdict[et_precise[0]].append(e[0]) 736 else: 737 # add event name at this time (when time is not already in dict) 738 ntpe_tdict[et_precise[0]] = [e[0]] 739 else: 740 lost_evt = True 741 if lost_evt: 742 print "Error: A non-terminal, 'precise' event was lost -- did you reset", 743 print "events prior to integration?" 744 raise PyDSTool_ExistError("Internal error: A non-terminal, " 745 "'precise' event '%s' was lost after integration!"%e[0]) 746 # add non-terminal event points to variables 747 if ntpe_tdict != {}: 748 # find indices of times at which event times will be inserted 749 tix = 0 750 evts = ntpe_tdict.keys() 751 evts.sort() 752 for evix in xrange(len(evts)): 753 evt = evts[evix] 754 evnames = ntpe_tdict[evt] 755 self.diagnostics.warnings.append((W_NONTERMEVENT, (evt, evnames))) 756 xval = [variables[x](evt) for x in xnames] 757 ilist = _pollInputs(inputVarList, evt+self.globalt0, 758 self.checklevel) 759 # find accurate dx and aux vars value at this time 760 if do_poly: 761 dxval = rhsfn(evt, xval, plist+ilist) 762 aval = list(auxvarsfn(evt, xval, plist+ilist)) 763 for evname in evnames: 764 Evtimes[evname].append(evt) 765 Evpoints[evname].append(array(xval+aval)) 766 tcond = less_equal(alltData[tix:], evt).tolist() 767 try: 768 tix = tcond.index(0) + tix # lowest index for t > evt 769 do_insert = (alltData[tix-1] != evt) 770 except ValueError: 771 # evt = last t value so no need to add it 772 do_insert = False 773 if do_insert: 774 alltData.insert(tix, evt) 775 for ai, a in enumerate(anames): 776 allaDataDict[a].insert(tix, aval[ai]) 777 if do_poly: 778 for xi, x in enumerate(xnames): 779 allxDataDict[x].insert(tix, xval[xi]) 780 
alldxDataDict[x].insert(tix, dxval[xi]) 781 else: 782 for xi, x in enumerate(xnames): 783 allxDataDict[x].insert(tix, xval[xi]) 784 for x in xnames: 785 if do_poly: 786 # use asarray in case points added to sequences as a list 787 xvals = array([asarray(allxDataDict[x]), 788 asarray(alldxDataDict[x])]).T 789 interp = PiecewisePolynomial(alltData, xvals, 2) 790 else: 791 xvals = allxDataDict[x] 792 interp = interp1d(alltData, xvals) 793 variables[x] = Variable(interp, 't', x, x) 794 for a in anames: 795 variables[a] = Variable(interp1d(alltData, allaDataDict[a]), 796 't', a, a) 797 self.diagnostics.outputStats = {'last_step': dt, 798 'num_fcns': num_points, 799 'num_steps': num_points, 800 'errorStatus': errcode 801 } 802 if solver.successful(): 803 #self.validateSpec() 804 for evname, evtlist in Evtimes.iteritems(): 805 try: 806 self.eventstruct.Evtimes[evname].extend([et+self.globalt0 \ 807 for et in evtlist]) 808 except KeyError: 809 self.eventstruct.Evtimes[evname] = [et+self.globalt0 \ 810 for et in evtlist] 811 # build event pointset information (reset previous trajectory's) 812 self.trajevents = {} 813 for (evname, ev) in eventslist: 814 evpt = Evpoints[evname] 815 if evpt == []: 816 self.trajevents[evname] = None 817 else: 818 evpt = transpose(array(evpt)) 819 self.trajevents[evname] = Pointset({'coordnames': xnames+anames, 820 'indepvarname': 't', 821 'coordarray': evpt, 822 'indepvararray': Evtimes[evname], 823 'indepvartype': float}) 824 self.defined = True 825 return Trajectory(trajname, variables.values(), 826 abseps=self._abseps, globalt0=self.globalt0, 827 checklevel=self.checklevel, 828 FScompatibleNames=self._FScompatibleNames, 829 FScompatibleNamesInv=self._FScompatibleNamesInv, 830 events=self.trajevents, 831 modelNames=self.name, 832 modelEventStructs=self.eventstruct) 833 else: 834 try: 835 self.diagnostics.errors.append((E_COMPUTFAIL, (solver.t, 836 self.diagnostics._errorcodes[errcode]))) 837 except TypeError: 838 # e.g. 
when errcode has been used to return info list 839 print "Error information: ", errcode 840 self.diagnostics.errors.append((E_COMPUTFAIL, (solver.t, 841 self.diagnostics._errorcodes[0]))) 842 self.defined = False
843 844
845 - def Rhs(self, t, xdict, pdict=None, asarray=True):
846 """asarray is an unused, dummy argument for compatibility with Model.Rhs""" 847 # must convert names to FS-compatible as '.' sorts before letters 848 # while '_' sorts after! 849 # also, ensure xdict doesn't contain elements like array([4.1]) instead of 4 850 x = [float(val) for val in sortedDictValues(filteredDict(self._FScompatibleNames(xdict), 851 self.funcspec.vars))] 852 if pdict is None: 853 pdict = self.pars 854 # internal self.pars already is FS-compatible 855 p = sortedDictValues(pdict) 856 else: 857 p = sortedDictValues(self._FScompatibleNames(pdict)) 858 i = _pollInputs(sortedDictValues(self.inputs), 859 t, self.checklevel) 860 return apply(getattr(self, self.funcspec.spec[1]), [t, x, p+i])
861 862
863 - def Jacobian(self, t, xdict, pdict=None, asarray=True):
864 """asarray is an unused, dummy argument for compatibility with 865 Model.Jacobian""" 866 if self.haveJacobian(): 867 # also, ensure xdict doesn't contain elements like array([4.1]) instead of 4 868 x = [float(val) for val in sortedDictValues(filteredDict(self._FScompatibleNames(xdict), 869 self.funcspec.vars))] 870 if pdict is None: 871 pdict = self.pars 872 # internal self.pars already is FS-compatible 873 p = sortedDictValues(pdict) 874 else: 875 p = sortedDictValues(self._FScompatibleNames(pdict)) 876 i = _pollInputs(sortedDictValues(self.inputs), 877 t, self.checklevel) 878 return apply(getattr(self, self.funcspec.auxfns["Jacobian"][1]), \ 879 [t, x, p+i]) 880 else: 881 raise PyDSTool_ExistError("Jacobian not defined")
882 883
884 - def JacobianP(self, t, xdict, pdict=None, asarray=True):
885 """asarray is an unused, dummy argument for compatibility with 886 Model.JacobianP""" 887 if self.haveJacobian_pars(): 888 # also, ensure xdict doesn't contain elements like array([4.1]) instead of 4 889 x = [float(val) for val in sortedDictValues(filteredDict(self._FScompatibleNames(xdict), 890 self.funcspec.vars))] 891 if pdict is None: 892 pdict = self.pars 893 # internal self.pars already is FS-compatible 894 p = sortedDictValues(pdict) 895 else: 896 p = sortedDictValues(self._FScompatibleNames(pdict)) 897 i = _pollInputs(sortedDictValues(self.inputs), 898 t, self.checklevel) 899 return apply(getattr(self, self.funcspec.auxfns["Jacobian_pars"][1]), \ 900 [t, x, p+i]) 901 else: 902 raise PyDSTool_ExistError("Jacobian w.r.t. parameters not defined")
903 904
905 - def AuxVars(self, t, xdict, pdict=None, asarray=True):
906 """asarray is an unused, dummy argument for compatibility with 907 Model.AuxVars""" 908 # also, ensure xdict doesn't contain elements like array([4.1]) instead of 4 909 x = [float(val) for val in sortedDictValues(filteredDict(self._FScompatibleNames(xdict), 910 self.funcspec.vars))] 911 if pdict is None: 912 pdict = self.pars 913 # internal self.pars already is FS-compatible 914 p = sortedDictValues(pdict) 915 else: 916 p = sortedDictValues(self._FScompatibleNames(pdict)) 917 i = _pollInputs(sortedDictValues(self.inputs), 918 t, self.checklevel) 919 return apply(getattr(self, self.funcspec.auxspec[1]), [t, x, p+i])
920 921
922 - def __del__(self):
923 ODEsystem.__del__(self)
# Register this Generator with the database under the 'python' target.
# The symbol map is currently empty: appropriate mappings for libraries
# such as math and random may be provided in future (for now this is left
# to FuncSpec).
symbolMapDict = dict()
theGenSpecHelper.add(Vode_ODEsystem, symbolMapDict, 'python')