Package PyDSTool :: Module common
[hide private]
[frames] | [no frames]

Source Code for Module PyDSTool.common

   1  """
 
   2      Internal utilities.
 
   3  
 
   4      Robert Clewley, September 2005.
 
   5  """ 
   6  
 
   7  from errors import * 
   8  
 
   9  import sys, types 
  10  import numpy as npy 
  11  import scipy as spy 
  12  from scipy.optimize import minpack 
  13  # In future, will convert these specific imports to be referred as npy.X
 
  14  from numpy import Inf, NaN, atleast_1d, clip, less, greater, logical_or, \
 
  15       searchsorted, isfinite, shape, mat, sign, any, all, sometrue, alltrue, \
 
  16       array, swapaxes, zeros, ones, finfo, double, exp, log, \
 
  17       take, less_equal, putmask, ndarray, asarray, \
 
  18       int, float, complex, complexfloating, integer, floating, \
 
  19       int_, int0, int8, int16, int32, int64, float_, float32, float64, \
 
  20       complex_, complex64, complex128, argmin, argmax 
  21  from numpy.linalg import norm 
  22  from math import sqrt 
  23  
 
  24  try: 
  25      from numpy import float96 
  26  except ImportError: 
  27      _all_numpy_float = (float_, float32, float64) 
  28  else: 
  29      _all_numpy_float = (float_, float32, float64, float96) 
  30  
 
  31  
 
  32  try: 
  33      from numpy import complex192 
  34  except ImportError: 
  35      _all_numpy_complex = (complex_, complex64, complex128) 
  36  else: 
  37      _all_numpy_complex = (complex_, complex64, complex128, complex192) 
  38  
 
  39  
 
import time
from copy import copy, deepcopy
import os
# Select the pickle implementation used throughout PyDSTool:
# on Windows ('nt') a bundled replacement is used; elsewhere the fast
# C implementation cPickle (Python 2) is used.
if os.name == 'nt':
    # slow object copying for you guys
    import fixedpickle as pickle
else:
    import cPickle as pickle
  48  
 
# ----------------------------------------------------------------------------
### EXPORTS

# Names of the public classes defined in this module.
_classes = ['Verbose', 'interpclass', 'interp0d', 'interp1d', 'Utility',
            'args', 'DefaultDict', 'Struct', 'pickle', 'Diagnostics',
            'metric', 'metric_float', 'metric_float_1D', 'metric_L2',
            'metric_L2_1D', 'metric_weighted_L2', 'metric_weighted_deadzone_L2',
            'predicate', 'null_predicate', 'and_op', 'or_op', 'not_op']

# Names of the public type-mapping dictionaries.
_mappings = ['_num_type2name', '_num_name2type',
             '_num_equivtype', '_num_name2equivtypes',
             '_pytypefromtype', '_num_maxmin'
             ]

# Names of the public utility functions.
_functions = ['isUniqueSeq', 'makeArrayIxMap', 'className',
              'compareBaseClass', 'compareClassAndBases', 'timestamp',
              'makeUniqueFn', 'copyVarDict', 'concatStrDict',
              'invertMap', 'makeSeqUnique', 'insertInOrder', 'uniquePoints',
              'sortedDictKeys', 'sortedDictValues', 'sortedDictItems',
              'sortedDictLists', 'compareNumTypes', 'diff', 'diff2',
              'listid', 'idfn', 'noneFn', 'isincreasing', 'ismonotonic',
              'extent', 'n_sigdigs_str',
              'linearInterp', 'object2str', 'getSuperClasses',
              'filteredDict', 'arraymax', 'simplifyMatrixRepr',
              'makeMultilinearRegrFn', 'fit_quadratic', 'fit_quadratic_at_vertex',
              'fit_exponential', 'fit_diff_of_exp', 'fit_linear', 'fit_cubic',
              'smooth_pts', 'nearest_2n_indices',
              'KroghInterpolator', 'BarycentricInterpolator',
              'PiecewisePolynomial', 'make_poly_interpolated_curve',
              'simple_bisection', 'get_opt', 'array_bounds_check',
              'verify_intbool', 'verify_nonneg', 'verify_pos',
              'verify_values', 'ensurefloat', 'API']

# Names of the public constants and type groupings.
_constants = ['Continuous', 'Discrete', 'targetLangs', '_seq_types',
              '_num_types', '_int_types', '_float_types', '_complex_types',
              '_real_types', '_all_numpy_int', '_all_numpy_float',
              '_all_numpy_complex', '_all_int', '_all_float', '_all_complex',
              'LargestInt32']

# NB: some listed names are defined later in this module than this point.
__all__ = _functions + _mappings + _classes + _constants
 
# ----------------------------------------------------------------------------

# global reference for supported target languages
targetLangs = ['c', 'python', 'matlab'] #, 'xpp', 'dstool'


# type mappings and groupings

# Numeric types accepted by PyDSTool (complex deliberately excluded for now).
_num_types = (float, int, floating, integer) # complex, complexfloating

_int_types = (int, integer)
_float_types = (float, floating)
_complex_types = (complex, complexfloating)
_real_types = (int, integer, float, floating)

# Sequence-like containers recognized throughout the package.
_seq_types = (list, tuple, ndarray)

_all_numpy_int = (int_, int0, int8, int16, int32, int64)

# Unions of builtin Python scalar types with the numpy scalar types
# (_all_numpy_float/_all_numpy_complex are platform-dependent, set above).
_all_int = (int, integer)+_all_numpy_int
_all_float = (float, floating)+_all_numpy_float
_all_complex = (complex, complexfloating)+_all_numpy_complex

# Largest value representable by a signed 32-bit integer (2**31 - 1).
LargestInt32 = 2147483647
# Machine epsilon for double precision.
Macheps = finfo(double).eps
# bind common names
# Map every accepted scalar type to PyDSTool's canonical category name
# ('float' or 'int') and to the single numpy type used internally for
# that category.  NB: the loop variables f and i leak into the module
# namespace (Python 2 semantics) -- kept as-is.
_num_type2name = {float: 'float', int: 'int'} #, complex: 'complex'}
_num_equivtype = {float: float64, int: int32} #, complex: complex128}
for f in _all_float:
    _num_type2name[f] = 'float'
    _num_equivtype[f] = float64
for i in _all_int:
    _num_type2name[i] = 'int'
    _num_equivtype[i] = int32
# Don't yet support complex numbers
##for c in _all_complex:
##    _num_type2name[c] = 'complex'
##    _num_equivtype[c] = complex128

# equivalent types for comparison
_num_name2equivtypes = {'float': _all_float,
                'int': _all_int}
##                'complex': _all_complex}

# default types used by PyDSTool when named
_num_name2type = {'float': float64, 'int': int32} #, 'complex': complex128}

# Representable extremes for each canonical numeric type.
_num_maxmin = {float64: [-Inf, Inf],
             int32: [-LargestInt32-1, LargestInt32],
##             complex128: [-Inf-Inf*1.0j, Inf+Inf*1.0j]
             }

# Conversions between builtin Python scalar types and numpy equivalents.
_typefrompytype = {float: float64, int: int32} #, complex: complex128}
_pytypefromtype = {float64: float, int32: int} #, complex128: complex}


#-------------------------------------------------------------------------
class API_class(object):
    """Produce a readable summary of an object's public API.

    Calling the instance on an object (or class) returns its name (when
    available) and docstring, followed by one line per public attribute:
    (Python 2) methods show their docstring, other values show the name
    of their type.  Adapted from ."""
    def _print_values(self, obj):
        # Build one ' key : description' line per public attribute.
        def describe(key):
            if key.startswith('_'):
                return ''
            value = getattr(obj, key)
            if not hasattr(value, 'im_func'):
                # not a (Python 2) method object: show the value's type
                doc = type(value).__name__
            else:
                if value.__doc__ is None:
                    doc = 'no docstring'
                else:
                    doc = value.__doc__
            return ' %s : %s' % (key, doc)
        described = [describe(key) for key in dir(obj)]
        return '\n'.join([line for line in described if line != ''])

    def __call__(self, obj):
        if obj.__doc__ is None:
            doc = 'No docstring'
        else:
            doc = obj.__doc__
        body = doc + "\n\n" + self._print_values(obj)
        if hasattr(obj, '__name__'):
            return obj.__name__ + " : " + body
        else:
            return body

API = API_class()

class Struct(object):
    """Simple record type: keyword arguments become instance attributes.

    (The args class is a more sophisticated type of Struct.)
    """
    def __init__(self, **entries):
        self.__dict__.update(entries)

    def __repr__(self):
        public = [name for name in dir(self) if not name.startswith('_')]
        return 'Struct(' + ', '.join(public) + ')'

class DefaultDict(dict):
    """Dictionary returning a deep-copied default value for unknown keys.

    Missing keys are inserted on first access (setdefault semantics).
    Written by Peter Norvig."""
    def __init__(self, default):
        # prototype value handed back (as a deep copy) for missing keys
        self.default = default

    def __getitem__(self, key):
        try:
            return dict.__getitem__(self, key)
        except KeyError:
            return self.setdefault(key, deepcopy(self.default))


### PREDICATES ETC

class predicate_op(object):
    """Abstract boolean combinator over a collection of predicates.

    Sub-classes supply a class attribute `name` and an `evaluate` method;
    calling the instance evaluates the combination and stores a trace of
    the component predicates' records in self.record.
    """
    def __init__(self, predicates):
        self.predicates = predicates
        self.record = []

    def precondition(self, objlist):
        outcome = npy.all([p.precondition(objlist) for p in self.predicates])
        self.record = [(self.name, [p.record for p in self.predicates])]
        return outcome

    def __call__(self, obj):
        outcome = self.evaluate(obj)
        self.record = [(self.name, [p.record for p in self.predicates])]
        return outcome

    def evaluate(self, obj):
        # abstract: sub-classes implement the boolean combination
        raise NotImplementedError

class and_op(predicate_op):
    """Logical AND of all component predicates."""
    name = 'AND'

    def evaluate(self, obj):
        outcomes = [p(obj) for p in self.predicates]
        return npy.all(outcomes)

class or_op(predicate_op):
    """Logical OR of all component predicates."""
    name = 'OR'

    def evaluate(self, obj):
        outcomes = [p(obj) for p in self.predicates]
        return npy.any(outcomes)

class not_op(predicate_op):
    """Logical negation of a single predicate.

    NB: unlike the other combinators this wraps exactly one predicate,
    and its record is a flat [name, record] list rather than a list of
    (name, records) tuples.
    """
    name = 'NOT'

    def __init__(self, predicate):
        self.predicate = predicate
        self.record = []

    def precondition(self, objlist):
        outcome = self.predicate.precondition(objlist)
        self.record = [self.name, self.predicate.record]
        return outcome

    def __call__(self, obj):
        outcome = self.evaluate(obj)
        self.record = [self.name, self.predicate.record]
        return outcome

    def evaluate(self, obj):
        return not self.predicate(obj)

class predicate(object):
    """Base class for single-subject predicates.

    Sub-classes implement `evaluate` (and optionally override `name` and
    `precondition`).  Calling the instance evaluates it on an object and
    records (name, subject, result) in self.record.
    """
    # override name in subclass if needed
    name = ''

    def __init__(self, subject):
        self.subject = subject
        self.record = []

    def precondition(self, objlist):
        """Override if needed"""
        return True

    def __call__(self, obj):
        outcome = self.evaluate(obj)
        self.record = (self.name, self.subject, outcome)
        return outcome

    def evaluate(self, obj):
        raise NotImplementedError

class null_predicate_class(predicate):
    """Trivial predicate: always satisfied."""
    name = 'null'

    def evaluate(self, obj):
        return True

null_predicate = null_predicate_class(None)

# ------------------------------------------------------

class metric(object):
    """Abstract metric class for quantitatively comparing scalar or vector
    quantities.
    Can include optional explicit Jacobian function.

    Create concrete sub-classes for specific applications.
    Store the measured (*1D array only*) value in self.results for use as part
    of a parameter estimation residual value. Residual norm will be taken
    by optimizer routines.
    """
    def __init__(self):
        # concrete sub-classes overwrite this with a 1D array on each call
        self.results = None

    def __call__(self, x, y):
        raise NotImplementedError("Override with a concrete sub-class")

    def Jac(self, x, y):
        raise NotImplementedError("Override with a concrete sub-class")


class metric_float(metric):
    """Simple metric between two real-valued floats.
    """
    def __call__(self, x, y):
        residual = asarray([x - y]).flatten()
        self.results = residual
        return norm(residual)


class metric_float_1D(metric):
    """Simple metric between two real-valued floats. Version that is suitable for
    scalar optimizers such as BoundMin.
    """
    def __call__(self, x, y):
        residual = abs(asarray([x - y]).flatten())
        self.results = residual
        return residual[0]


class metric_L2(metric):
    """Measures the standard "distance" between two 1D pointsets or arrays
    using the L-2 norm."""
    def __call__(self, pts1, pts2):
        residual = asarray(pts1 - pts2).flatten()
        self.results = residual
        return norm(residual)


class metric_L2_1D(metric):
    """Measures the standard "distance" between two 1D pointsets or arrays
    using the L-2 norm."""
    def __call__(self, pts1, pts2):
        norm_val = norm(asarray(pts1 - pts2).flatten())
        self.results = array([norm_val])
        return norm_val


class metric_weighted_L2(metric):
    """Measures the standard "distance" between two 1D pointsets or arrays
    using the L-2 norm, after weighting by weights attribute
    (must set weights after creation, e.g. in a feature's _local_init
    method)."""
    def __call__(self, pts1, pts2):
        weighted = array(pts1 - pts2).flatten() * self.weights
        self.results = weighted
        return norm(weighted)


class metric_weighted_deadzone_L2(metric):
    """Measures the standard "distance" between two 1D pointsets or arrays
    using the L-2 norm, after weighting by weights attribute.
    Then, sets distance vector entries to zero if they fall
    below corresponding entries in the deadzone vector/scalar.
    (Must set weights and deadzone vectors/scalars after creation, e.g.
    in a feature's _local_init method).
    """
    def __call__(self, pts1, pts2):
        weighted = array(pts1 - pts2).flatten() * self.weights
        # zero out every entry whose magnitude is within the deadzone band
        weighted = (abs(weighted) > self.deadzone).astype(int) * weighted
        self.results = weighted
        return norm(weighted)

def n_sigdigs_str(x, n):
    """Return a string representation of float x with n significant digits,
    where n > 0 is an integer.
    """
    # Fix: local was named 'format', shadowing the builtin -- renamed to fmt.
    # '%.<n>g' rounds to n significant digits; the round-trip through float()
    # normalizes exponent forms such as '1.2e+05' to '120000.0'.
    fmt = "%." + str(int(n)) + "g"
    s = '%s' % float(fmt % x)
    if '.' in s:
        # handle trailing ".0" when not one of the sig. digits
        pt_idx = s.index('.')
        if s[0] == '-':
            # leading sign means pt_idx is one too large
            if pt_idx - 1 >= n:
                return s[:pt_idx]
        else:
            if pt_idx >= n:
                return s[:pt_idx]
    return s
387 -class args(object):
388 """Mapping object class for building arguments for class initialization 389 calls. Treat as a dictionary. 390 """ 391
392 - def __init__(self, **kw):
393 self.__dict__ = kw
394
395 - def _infostr(self, verbose=1, attributeTitle='args'):
396 # removed offset=0 from arg list 397 if len(self.__dict__) > 0: 398 res = "%s ("%attributeTitle 399 for k, v in self.__dict__.iteritems(): 400 try: 401 istr = v._infostr(verbose-1) #, offset+2) 402 except AttributeError: 403 istr = str(v) 404 res += "\n%s%s = %s,"%(" ",k,istr) 405 # was " "*offset 406 # skip last comma 407 res = res[:-1] + "\n)" 408 return res 409 else: 410 return "No %s defined"%attributeTitle
411
412 - def __repr__(self):
413 return self._infostr()
414
415 - def info(self):
416 print self._infostr()
417 418 __str__ = __repr__ 419
420 - def values(self):
421 return self.__dict__.values()
422
423 - def keys(self):
424 return self.__dict__.keys()
425
426 - def items(self):
427 return self.__dict__.items()
428
429 - def itervalues(self):
430 return self.__dict__.itervalues()
431
432 - def iterkeys(self):
433 return self.__dict__.iterkeys()
434
435 - def iteritems(self):
436 return self.__dict__.iteritems()
437
438 - def __getitem__(self, k):
439 return self.__dict__[k]
440
441 - def __setitem__(self, k, v):
442 self.__dict__.__setitem__(k, v)
443
444 - def update(self, d):
445 self.__dict__.update(d)
446
447 - def copy(self):
448 return copy(self)
449
450 - def clear(self):
451 self.__dict__.clear()
452
453 - def get(self, k, d=None):
454 return self.__dict__.get(k, d)
455
456 - def has_key(self, k):
457 return self.__dict__.has_key(k)
458
459 - def pop(self, k, d=None):
460 return self.__dict__.pop(k, d)
461
462 - def popitem(self):
463 raise NotImplementedError
464
465 - def __contains__(self, v):
466 return self.__dict__.__contains__(v)
467
468 - def fromkeys(self, S, v=None):
469 raise NotImplementedError
470
471 - def setdefault(self, d):
472 raise NotImplementedError
473
474 - def __delitem__(self, k):
475 del self.__dict__[k]
476
477 - def __cmp__(self, other):
478 return self.__dict__ == other
479
480 - def __eq__(self, other):
481 return self.__dict__ == other
482
483 - def __ne__(self, other):
484 return self.__dict__ != other
485
486 - def __gt__(self, other):
487 return self.__dict__ > other
488
489 - def __ge__(self, other):
490 return self.__dict__ >= other
491
492 - def __lt__(self, other):
493 return self.__dict__ < other
494
495 - def __le__(self, other):
496 return self.__dict__ <= other
497
498 - def __len__(self):
499 return len(self.__dict__)
500
501 - def __iter__(self):
502 return iter(self.__dict__)
503
504 - def __add__(self, other):
505 d = self.__dict__.copy() 506 d.update(other.__dict__) 507 return args(**d)
508 509
def get_opt(argopt, attr, default=None):
    """Get option from args object otherwise default to the given value. Can
    also specify that an AttributeError is raised by passing default=Exception.
    """
    try:
        return getattr(argopt, attr)
    except AttributeError:
        pass
    if default is Exception:
        raise PyDSTool_AttributeError("Missing option: "+attr)
    return default

523 -class Diagnostics(object):
524 """General purpose diagnostics manager.""" 525
526 - def __init__(self, errmessages=None, errorfields=None, warnmessages=None, 527 warnfields=None, errorcodes=None, warncodes=None, 528 outputinfo=None, propagate_dict=None):
529 if warnfields is None: 530 warnfields = {} 531 if warnmessages is None: 532 warnmessages = {} 533 if warncodes is None: 534 warncodes = {} 535 if errorfields is None: 536 errorfields = {} 537 if errmessages is None: 538 errmessages = {} 539 if errorcodes is None: 540 errorcodes = {} 541 self._warnfields = warnfields 542 self._warnmessages = warnmessages 543 self._warncodes = warncodes 544 self._errorfields = errorfields 545 self._errmessages = errmessages 546 self._errorcodes = errorcodes 547 self.errors = [] 548 self.warnings = [] 549 # traceback may store information about variable state, pars, etc. 550 # at time of an error that breaks the solver 551 self.traceback = {} 552 self.outputStatsInfo = outputinfo 553 self.outputStats = {} 554 if propagate_dict is None: 555 # use dict so that un-initialized inputs attribute 556 # of generator etc. can be passed-by-reference 557 self.propagate_dict = {} 558 else: 559 self.propagate_dict = propagate_dict
560
561 - def update(self, d):
562 """Update warnings and errors from another diagnostics object""" 563 self.traceback.update(d.traceback) 564 self.warnings.extend(d.warnings) 565 self.errors.extend(d.errors) 566 self.outputStats.update(d.outputStats) 567 self._warnfields.update(d._warnfields) 568 self._warnmessages.update(d._warnmessages) 569 self._warncodes.update(d._warncodes) 570 self._errorfields.update(d._errorfields) 571 self._errmessages.update(d._errmessages) 572 self._errorcodes.update(d._errorcodes)
573
574 - def clearAll(self):
575 self.clearErrors() 576 self.clearWarnings() 577 self.outputStats = {} 578 self.traceback = {}
579
580 - def clearWarnings(self):
581 self.warnings = [] 582 for obj in self.propagate_dict.values(): 583 try: 584 obj.diagnostics.clearWarnings() 585 except AttributeError: 586 if hasattr(obj, 'name'): 587 name = obj.name 588 else: 589 name = str(obj) 590 raise TypeError("Object %s has no diagnostics manager"%name)
591
592 - def showWarnings(self):
593 if len(self.warnings)>0: 594 print self.getWarnings()
595
596 - def getWarnings(self):
597 if len(self.warnings)>0: 598 output = 'Warnings: ' 599 for (w, d) in self.warnings: 600 dstr = '' 601 for i in range(len(d)): 602 dentry = d[i] 603 dstr += self._warnfields[w][i] + ' = ' + str(dentry) + ", " 604 dstr = dstr[:-2] # drop trailing comma 605 output += ' Warning code %s: %s\n Info: %s ' %(w, \ 606 self._warnmessages[w], dstr) 607 else: 608 output = '' 609 return output
610
611 - def findWarnings(self, code):
612 """Return time-ordered list of warnings of kind specified using a 613 single Generator warning code""" 614 res = [] 615 for wcode, (t, name) in self.warnings: 616 if wcode == code: 617 res.append((t, name)) 618 res.sort() # increasing order 619 return res
620
621 - def hasWarnings(self):
622 return self.warnings != []
623
624 - def hasErrors(self):
625 return self.errors != []
626
627 - def clearErrors(self):
628 self.errors = [] 629 for obj in self.propagate_dict.values(): 630 try: 631 obj.diagnostics.clearErrors() 632 except AttributeError: 633 if hasattr(obj, 'name'): 634 name = obj.name 635 else: 636 name = str(obj) 637 raise TypeError("Object %s has no diagnostics manager"%name)
638
639 - def showErrors(self):
640 if len(self.errors)>0: 641 print self.getErrors()
642
643 - def getErrors(self):
644 if len(self.errors)>0: 645 output = 'Errors: ' 646 for (e, d) in self.errors: 647 dstr = '' 648 for i in range(len(d)): 649 dentry = d[i] 650 dstr += self._errorfields[e][i] + ' = ' + str(dentry) + ", " 651 dstr = dstr[:-2] # drop trailing comma 652 output += ' Error code %s: %s\n Info:\n %s ' %(e, \ 653 self._errmessages[e], dstr) 654 else: 655 output = '' 656 return output
657
658 - def info(self, verboselevel=0):
659 self.showErrors() 660 self.showWarnings()
661 662 663 ## ------------------------------------------------------------------ 664 665 ## Internally used functions 666
def compareNumTypes(t1, t2):
    """Return True-ish if type t1 belongs to the same numeric category
    ('int' vs 'float') as type t2, or as any member of t2 when t2 is a
    sequence of types.  Types unknown to _num_type2name compare False.
    """
    try:
        name1 = _num_type2name[t1]
    except KeyError:
        return False
    try:
        return sometrue([name1 == _num_type2name[t] for t in t2])
    except TypeError:
        # t2 not iterable, assume singleton
        try:
            return name1 == _num_type2name[t2]
        except KeyError:
            return False
    except KeyError:
        return False

def filteredDict(d, keys, neg=False):
    """returns filtered dictionary containing specified keys,
    or *not* containing the specified keys if option neg=True."""
    if neg:
        selected = remain(d.keys(), keys)
    else:
        selected = keys
    result = {}
    for k in selected:
        # keep try/except (not `k in d`) so dict subclasses that synthesize
        # missing keys behave the same as with plain d[k] access
        try:
            result[k] = d[k]
        except KeyError:
            pass
    return result

def concatStrDict(d, order=None):
    """Concatenate all entries of a dictionary (assumed to be lists of
    strings) into a single string, in optionally specified key order.

    d     : dict mapping keys to lists of strings
    order : optional sequence of keys fixing the concatenation order
            (default: d.keys() iteration order)
    """
    # Fix: default was a shared mutable list (order=[]); use None sentinel.
    # `order is None or order == []` keeps the original semantics exactly.
    retstr = ''
    if d != {}:
        if order is None or order == []:
            order = d.keys()
        for key in order:
            itemlist = d[key]
            for strlist in itemlist:
                retstr += ''.join(strlist)
    return retstr

def copyVarDict(vardict, only_cts=False):
    """Copy dictionary of Variable objects.
    Use the only_cts Boolean optional argument (default False) to select only
    continuous-valued variables (mainly for internal use).
    """
    if only_cts:
        # NOTE(review): this branch returns the original Variable objects
        # (no copy() applied), unlike the branch below -- confirm intended
        names = []
        selected = []
        for varname in sortedDictKeys(vardict):
            var = vardict[varname]
            if var.is_continuous_valued():
                names.append(varname)
                selected.append(var)
        return dict(zip(names, selected))
    else:
        copies = [copy(v) for v in sortedDictValues(vardict)]
        return dict(zip(sortedDictKeys(vardict), copies))

def insertInOrder(sourcelist, inslist, return_ixs=False, abseps=0):
    """Insert elements of inslist into sourcelist, sorting these
    lists in case they are not already in increasing order. The new
    list is returned.

    The function will not create duplicate entries in the list, and will
    change neither the first or last entries of the list.

    If sourcelist is an array, an array is returned.
    If optional return_ixs=True, the indices of the inserted elements
    in the returned list is returned as an additional return argument.
    If abseps=0 (default) the comparison of elements is done exactly. For
    abseps > 0 elements are compared up to an absolute difference no
    greater than abseps for determining "equality".

    NOTE(review): with abseps=0 the closeness test abs(diff) >= abseps is
    always true, so exact duplicates ARE inserted in that case despite the
    docstring -- confirm intended.
    """
    # normalize both inputs to sorted plain lists, remembering array-ness
    try:
        ins_sorted = inslist.tolist()
    except AttributeError:
        ins_sorted = copy(inslist)
    ins_sorted.sort()
    try:
        src_sorted = sourcelist.tolist()
        was_array = True
    except AttributeError:
        src_sorted = copy(sourcelist)
        was_array = False
    src_sorted.sort()
    close_ixs = []   # (value, index) pairs deemed "close" within abseps
    ins_ixs = []     # indices where values were actually inserted
    tix = 0
    # single unified loop (original had two near-identical copies, one
    # tracking ins_ixs and one not -- tracking unconditionally is harmless)
    for t in ins_sorted:
        tcond = less_equal(src_sorted[tix:], t).tolist()
        try:
            tix = tcond.index(0) + tix  # lowest index for elt > t
        except ValueError:
            # no 0 value in tcond, so t might be equal to the final value
            if abs(src_sorted[-1] - t) < abseps:
                close_ixs.append((t, len(src_sorted)-1))
        else:
            if abs(src_sorted[tix-1] - t) >= abseps:
                if tix >= 0:
                    src_sorted.insert(tix, t)
                    ins_ixs.append(tix)
            else:
                close_ixs.append((t, tix-1))
    if was_array:
        result = array(src_sorted)
    else:
        result = src_sorted
    if return_ixs:
        if abseps > 0:
            return result, ins_ixs, dict(close_ixs)
        else:
            return result, ins_ixs
    else:
        if abseps > 0:
            return result, dict(close_ixs)
        else:
            return result

814 -def arraymax(a1,a2,t=float64):
815 """Element-wise comparison of maximum values for two arrays.""" 816 o=[] 817 try: 818 for x, y in zip(a1,a2): 819 o.append(max(x,y)) 820 except TypeError: 821 print "Problem with type of arguments in arraymax:" 822 print "Received a1 =", a1 823 print " a2 =", a2 824 raise 825 return array(o,t)
826 827
def simplifyMatrixRepr(m):
    """Convert matrix object to a compact array
    representation or numeric value, recursively stripping leading
    singleton dimensions."""
    ma = array(m)
    ndims = len(shape(ma))
    if ndims == 0:
        # already scalar-like: hand back the original object
        return m
    if shape(ma)[0] == 1:
        return simplifyMatrixRepr(ma[0])
    if ndims > 1 and shape(ma)[1] == 1:
        return simplifyMatrixRepr(ma[:, 0])
    return ma

def makeMultilinearRegrFn(arg, xs, ys):
    """Convert two lists or arrays mapping x intervals to y intervals
    into a string function definition of a multilinear regression
    scalar function that these define. A.k.a. makes a "piecewise
    linear" scalar function from the input data. The two input data
    sequences can each be either all numeric values or all
    strings/symbolic objects, but not a mixture. """
    assert len(xs)==len(ys), \
           "You must give x and y lists that are the same length"
    assert not isinstance(arg, _num_types), \
           "arg must be a string or symbolic object"
    argname = str(arg)

    # symbolic vs numeric subtraction renderers
    def sub_str(a, b):
        return '(' + str(a) + '-' + str(b) + ')'

    def sub_val(a, b):
        return repr(a - b)

    # linear interpolant string for segment n (uses rep_x/sub_x etc.
    # chosen below; late binding of those names is intentional)
    def interp(n):
        return rep_y(ys[n-1]) +'+(' + argname + '-(' + rep_x(xs[n-1]) \
            + '))*' + sub_y(ys[n],ys[n-1]) +'/'+ sub_x(xs[n],xs[n-1])

    xs_numeric = [isinstance(x, _num_types) for x in xs]
    if all(xs_numeric):
        rep_x = repr
        sub_x = sub_val
    elif any(xs_numeric):
        raise TypeError("xlist must contain either all string/symbolic types "
                        "or all numeric values")
    else:
        rep_x = str
        sub_x = sub_str
    ys_numeric = [isinstance(y, _num_types) for y in ys]
    if all(ys_numeric):
        rep_y = repr
        sub_y = sub_val
    elif any(ys_numeric):
        raise TypeError("ylist must contain either all string/symbolic types "
                        "or all numeric values")
    else:
        rep_y = str
        sub_y = sub_str
    terms = ['heav(%s-%s)*(1-heav(%s-%s))*(%s)'%(argname, \
                rep_x(xs[n-1]), argname, rep_x(xs[n]), interp(n)) \
             for n in range(1, len(xs))]
    return ([argname], '+'.join(terms))

891 -def _scalar_diff(func, x0, dx):
892 """Numerical differentiation of scalar function by central differences. 893 Returns tuple containing derivative evaluated at x0 and error estimate, 894 using Ridders' method and Neville's algorithm. 895 """ 896 max_order = 10 897 BIG = 1e50 898 CON = 1.4 899 CON2 = CON*CON 900 SAFE = 2 901 a=zeros((max_order,max_order),'f') 902 a[0,0] = (func(x0+dx)-func(x0-dx))/(2.*dx) 903 err=BIG 904 ans = NaN 905 for i in range(1,max_order): 906 dx /= CON 907 # try a smaller stepsize 908 a[0,i] = (func(x0+dx)-func(x0-dx))/(2.*dx) 909 fac = CON2 910 for j in range(1,i): 911 # compute extrapolations of various orders, using Neville's 912 # algorithm 913 a[j,i] = (a[j-1,i]*fac-a[j-1,i-1])/(fac-1.) 914 fac *= CON2 915 errt = max([abs(a[j,i]-a[j-1,i]),abs(a[j,i]-a[j-1][i-1])]) 916 # error strategy: 917 # compare each new extrapolation to one order lower, both at the 918 # present stepsize and the previous one 919 if errt <= err: 920 err = errt 921 ans = a[j,i] 922 if abs(a[i,i] - a[i-1,i-1]) >= SAFE*err: 923 # if higher order is worse by a significant factor SAFE, then 924 # quit early 925 break 926 return (ans, err, dx)
927 928
def diff(func, x0, vars=None, axes=None, eps=None, output=None):
    """Numerical 1st derivative of R^N -> R^M scalar or array function
    about x0 by central finite differences. Uses Ridders' method of
    polynomial extrapolation, based on an implementation in the book
    "Numerical Recipes". Returns a matrix.

    vars argument specifies which elements of x0 are to be treated as
    variables for the purposes of taking the Jacobian.
    If axes argument is unused or set to be all axes, the Jacobian of the
    function evaluated at x0 with respect to the variables is returned,
    otherwise a sub-matrix of it is returned.
    eps is assumed to be the scale in x for which the function varies by O(1).
    If eps is not given an appropriate step size is chosen.
    output = True returns an optional dictionary which will be updated
    with error and derivative information.
    """

    # Classify x0: numpy array, plain real number, or PyDSTool Point.
    if isinstance(x0, ndarray):
        x0type = 'array'
        if not compareNumTypes(x0.dtype.type, _all_float):
            raise TypeError("Only real-valued arrays valid")
    elif isinstance(x0, _real_types):
        x0type = 'num'
        x0 = float(x0)
    else:
        # Point type
        try:
            assert compareNumTypes(x0.coordtype, _all_float)
            x0.coordnames
            x0.dimension
        except (AssertionError, AttributeError):
            raise TypeError("Function and x0 must use real-valued scalar,"
                            "array, or Point types only")
        x0type = 'point'
    output_info = {}
    # Determine the differentiation variables and their count (dim).
    if vars is None:
        if x0type == 'array':
            dim = len(x0)
            vars = range(dim)
        elif x0type == 'num':
            dim = 1
            vars = [0]
        else:
            # Point type
            dim = x0.dimension
            vars = x0.coordnames
    else:
        assert isinstance(vars, _seq_types), \
               "vars argument must be a sequence type"
        if x0type in ['array', 'num']:
            assert all(vars>=0), \
                   "vars argument must hold non-negative integers"
        else:
            assert all([isinstance(vars[i], str) \
                    for i in range(len(vars))]), "vars argument must hold strings"
        dim = len(vars)
    # Evaluate once at x0 to learn the output shape.
    fx0 = func(x0)
    sfx0 = shape(fx0)
    try:
        # ensure fx0 is a vector or at least only a D x 1 matrix
        assert sfx0[1] == 1
    except IndexError:
        # if shape is of form (D,) then that's fine
        if len(sfx0) > 0:
            if sfx0[0] == 0:
                raise TypeError("Invalid function return type")
    except AssertionError:
        print "fx0 shape is", sfx0
        print fx0
        raise ValueError("Function should return an N-vector or N x 1 matrix,"
                         " but it returned a matrix with shape %s" % str(sfx0))
    # dimf = dimension M of the function's output.
    if isinstance(fx0, _float_types):
        dimf = 1
    elif isinstance(fx0, ndarray):
        if not compareNumTypes(fx0.dtype.type, _all_float):
            raise TypeError("Only real-valued functions valid")
        try:
            dimf = sfx0[0]
        except IndexError:
            dimf = 1
    else:
        try:
            assert compareNumTypes(fx0.coordtype, _all_float)
        except (AssertionError, AttributeError):
            raise TypeError("Only real-valued functions valid")
        dimf = sfx0[0]
    # Determine which output axes to differentiate (rows of the Jacobian).
    if axes is None:
        if x0type in ['array', 'num']:
            try:
                axes = range(sfx0[0])
            except IndexError:
                # then singleton (scalar) was returned
                axes = [0]
        else:
            axes = fx0.coordnames
    else:
        assert isinstance(axes, _seq_types), \
               "axes argument must be a sequence type"
        if x0type in ['array', 'num']:
            assert all(axes>=0), \
                   "axes argument must hold non-negative integers"
        else:
            assert all([isinstance(axes[i], str) \
                    for i in range(len(axes))]), "axes argument must hold strings"
    # Choose the finite-difference scale eps (default sqrt of machine eps).
    if eps is None:
        eps = sqrt(Macheps)
    else:
        assert all(eps > 0), "eps scaling array must be strictly positive"
        if isinstance(eps, _num_types):
            eps = ones(dim)*eps
        else:
            assert len(eps) == len(vars), \
                   "eps scaling array has length mismatch with vars"
    # Per-variable step sizes; the ==0 term guards against zero steps at
    # zero-valued coordinates.
    if x0type == 'array':
        dx = eps*(abs(x0[vars]) + array(x0[vars]==zeros(dim),'float64'))
    elif x0type == 'num':
        dx = eps*(abs(x0) + int(x0==0))
    else:
        # Point
        x0a = x0[vars].toarray()
        dx = dict(zip(vars,
                  eps*(abs(x0a) + array(x0a==zeros(dim),'float64'))))
    try:
        dim_mat = len(axes)
    except TypeError:
        raise TypeError("axes argument must be a sequence type")
    assert dim_mat <= dimf, "Number of axes greater than dimension of function"
    df = zeros([dim_mat,dim], 'float64')
    if x0type == 'array':
        # --- array case: differentiate each (axis j, variable i) entry ---
        output_info['error'] = zeros([dim_mat,dim], 'float64')
        output_info['dx'] = zeros([dim_mat,dim], 'float64')
        def update(xa, i, x):
            # set coordinate i of working copy xa and return it
            xa[i] = x
            return xa
        for i, vix in enumerate(vars):
            try:
                # for numpy arrays (otherwise copy returns a regular 'array'!)
                x0_d = x0.copy()
            except AttributeError:
                x0_d = copy(x0)
            if dimf > 1:
                for j in range(dim_mat):
                    f_d = lambda x: func(update(x0_d, vix, x))[axes[j]]
                    df_d, errval, dx_d = _scalar_diff(f_d, x0_d[vix], dx[i])
                    df[j,i] = df_d
                    output_info['error'][j,i] = errval
                    output_info['dx'][j,i] = dx_d
            else:
                for j in range(dim_mat):
                    f_d = lambda x: func(update(x0_d, vix, x))
                    df_d, errval, dx_d = _scalar_diff(f_d, x0_d[vix], dx[i])
                    df[j,i] = df_d
                    output_info['error'][j,i] = errval
                    output_info['dx'][j,i] = dx_d
        df = mat(df)
        output_info['df'] = df
        if output is not None:
            try:
                output.update(output_info)
            except:
                raise TypeError("Invalid type for 'output' argument")
        return df
    elif x0type == 'num':
        # --- scalar case: single central difference ---
        df, errval, dx_d = _scalar_diff(func, x0, dx)
        output_info['df'] = df
        output_info['error'] = errval
        output_info['dx'] = dx_d
        if output is not None:
            try:
                output.update(output_info)
            except:
                raise TypeError("Invalid type for 'output' argument")
        return df
    else:
        # --- Point type: index working copies by coordinate name ---
        output_info['error'] = zeros([dim_mat,dim], 'float64')
        output_info['dx'] = zeros([dim_mat,dim], 'float64')
        def update(xa, vn, x):
            # set named coordinate vn of working copy xa and return it
            xa[vn] = x
            return xa
        for i in range(dim):
            vname = vars[i]
            x0_d = copy(x0)
            for j in range(dim_mat):
                f_d = lambda x: func(update(x0_d, vname, x))[axes[j]]
                df_d, errval, dx_d = _scalar_diff(f_d, x0_d[vname], dx[vname])
                df[j,i] = df_d
                output_info['error'][j,i] = errval
                output_info['dx'][j,i] = dx_d
        df = mat(df)
        output_info['df'] = df
        if output is not None:
            try:
                output.update(output_info)
            except:
                raise TypeError("Invalid type for 'output' argument")
        return df
def diff2(func, x0, vars=None, axes=None, dir=1, eps=None):
    """Numerical 1st derivative of R^N -> R^M scalar or array function
    about x0 by forward or backward finite differences. Returns a matrix.

    dir=1 uses finite forward difference.
    dir=-1 uses finite backward difference.
    List-valued eps rescales finite differencing in each axis separately.
    vars argument specifies which elements of x0 are to be treated as
    variables for the purposes of taking the Jacobian.
    If axes argument is unused or set to be all axes, the Jacobian of the
    function evaluated at x0 with respect to the variables is returned,
    otherwise a sub-matrix of it is returned.
    eps is assumed to be the scale in x for which the function varies by O(1).
    If eps is not given an appropriate step size is chosen
    (proportional to sqrt(machine precision)).

    The returned matrix has one row per entry of axes and one column per
    entry of vars.
    """

    # Classify x0 as a real array, a real scalar, or a PyDSTool Point.
    if isinstance(x0, ndarray):
        x0type = 'array'
        if not compareNumTypes(x0.dtype.type, _all_float):
            # unlike diff(), try to coerce non-float arrays before giving up
            try:
                x0 = x0.astype(float)
            except:
                print "Found type:", x0.dtype.type
                raise TypeError("Only real-valued arrays valid")
    elif isinstance(x0, _real_types):
        x0type = 'num'
        x0 = float(x0)
    else:
        # Point type
        try:
            assert compareNumTypes(x0.coordtype, _all_float)
            x0.coordnames
            x0.dimension
        except (AssertionError, AttributeError):
            raise TypeError("Function and x0 must use real-valued scalar,"
                            "array, or Point types only")
        x0type = 'point'
    # default vars: differentiate with respect to every component of x0
    if vars is None:
        if x0type == 'array':
            dim = len(x0)
            vars = range(dim)
        elif x0type == 'num':
            dim = 1
            vars = [0]
        else:
            # Point type
            dim = x0.dimension
            vars = x0.coordnames
    else:
        assert isinstance(vars, _seq_types), \
               "vars argument must be a sequence type"
        if x0type in ['array', 'num']:
            assert all(vars>=0), \
                   "vars argument must hold non-negative integers"
        else:
            assert all([isinstance(vars[i], str) \
                   for i in range(len(vars))]), "vars argument must hold strings"
        dim = len(vars)
    # baseline function value, reused by every finite difference below
    fx0 = func(x0)
    sfx0 = shape(fx0)
    if isinstance(fx0, _float_types):
        dimf = 1
    elif isinstance(fx0, ndarray):
        if not compareNumTypes(fx0.dtype.type, _all_float):
            raise TypeError("Only real-valued functions valid")
        try:
            dimf = sfx0[0]
        except IndexError:
            dimf = 1
        try:
            # ensure fx0 is a vector or at least only a D x 1 matrix
            assert sfx0[1] == 1
        except IndexError:
            # if shape is of form (D,) then that's fine
            if len(sfx0) > 0:
                if sfx0[0] == 0:
                    raise TypeError("Invalid function return type")
            else:
                raise TypeError("Invalid function return type")
        except AssertionError:
            print "fx0 shape is", sfx0
            print fx0
            raise ValueError("Function should return an N-vector or N x 1 matrix,"
                         " but it returned a matrix with shape %s" % str(sfx0))
    else:
        try:
            assert compareNumTypes(fx0.coordtype, _all_float)
        except (AssertionError, AttributeError):
            raise TypeError("Only real-valued functions valid")
        dimf = sfx0[0]
    # default axes: use every component of the function's return value
    if axes is None:
        if x0type in ['array', 'num']:
            try:
                axes = range(sfx0[0])
            except IndexError:
                # then singleton (scalar) was returned
                axes = [0]
        else:
            axes = fx0.coordnames
    else:
        assert isinstance(axes, _seq_types), \
               "axes argument must be a sequence type"
        if x0type in ['array', 'num']:
            assert all(axes>=0), \
                   "axes argument must hold non-negative integers"
        else:
            assert all([isinstance(axes[i], str) \
                    for i in range(len(axes))]), "axes argument must hold strings"
    if eps is None:
        eps = sqrt(Macheps)
    else:
        assert all(eps > 0), "eps scaling array must be strictly positive"
        # NOTE(review): only builtin float is broadcast to a vector here,
        # whereas diff() tests isinstance(eps, _num_types) -- confirm that
        # numpy float scalars are meant to skip the rescaling branch.
        if isinstance(eps, float):
            if x0type in ['array', 'num']:
                eps = ones(dim)*eps
        else:
            assert len(eps) == len(vars), \
                   "eps scaling array has length mismatch with vars"
            eps = asarray(eps, 'float64')
    # ensure dx is not 0, and make into an appropriate length vector
    if x0type == 'array':
        dx = eps*(abs(x0[vars]) + array(x0[vars]==zeros(dim),'float64'))
    elif x0type == 'num':
        dx = eps*(abs(x0)+int(x0==0))
    else:
        # Point
        x0a = x0[vars].toarray()
        dx = dict(zip(vars,
                  eps*(abs(x0a) + array(x0a==zeros(dim),'float64'))))

    assert dir==1 or dir==-1, "Direction code must be -1 or 1"
    try:
        dim_mat = len(axes)
    except TypeError:
        raise TypeError("axes argument must be a sequence type")
    assert dim_mat <= dimf, "Number of axes greater than dimension of function"
    df = zeros([dim_mat,dim], 'float64')
    if x0type == 'array':
        # one function evaluation per variable (forward/backward difference)
        for i in range(dim):
            vix = vars[i]
            try:
                # for numpy arrays (otherwise copy returns a regular 'array'!)
                x0_d = x0.copy()
            except AttributeError:
                x0_d = copy(x0)
            x0_d[vix] += dir * dx[i]
            fx0_d = func(x0_d)
            if dim_mat > 1:
                fx0_d_v = array([fx0_d[n] for n in axes])
                fx0_v = array([fx0[n] for n in axes])
            else:
                if dimf > 1:
                    fx0_d_v = fx0_d[axes[0]]
                    fx0_v = fx0[axes[0]]
                else:
                    fx0_d_v = fx0_d
                    fx0_v = fx0
            df[:,i] = dir*(fx0_d_v - fx0_v)/dx[i]
        return mat(df)
    elif x0type == 'num':
        # scalar case: a single difference quotient
        x0_d = x0 + dir*dx
        fx0_d = func(x0_d)
        df = dir*(fx0_d - fx0)/dx
        return df
    else:
        # Point type
        for i in range(dim):
            vname = vars[i]
            x0_d = copy(x0)
            # Point call syntax x0_d(vname) reads the coordinate value
            x0_d[vname] = x0_d(vname) + dir * dx[vname]
            fx0_d = func(x0_d)[axes]
            fx0_v = fx0[axes]
            df[:,i] = dir*(fx0_d - fx0_v).toarray()/dx[vname]
        return mat(df)
1304 1305
def ensurefloat(v):
    """Coerce v to a builtin float.

    PyDSTool wrapper objects are first reduced to plain scalars: a
    singleton Point via toarray(), a numeric-literal Quantity via
    tonumeric(). Plain numbers pass straight through to float()."""
    for unwrap in ('toarray', 'tonumeric'):
        try:
            v = getattr(v, unwrap)()
        except AttributeError:
            # ordinary numeric types lack the conversion method
            pass
    return float(v)
1318 1319 _verify_type_names = {_all_int: 'an integer', 1320 _all_float: 'a float', 1321 _real_types: 'a real number', 1322 _all_complex: 'a complex number'} 1323 1324 # Only support lists because the primary use of these functions is for 1325 # checking input to SWIG-interfaced data structures passed down to C 1326 # and Fortran, which must be basic types only. 1327
def verify_values(name, value, values, list_ok=False, list_len=None):
    """Check that value is one of the allowed values.

    name     -- variable name used in error messages.
    value    -- the value to validate.
    values   -- sequence of acceptable values.
    list_ok  -- if True, value may also be a list whose elements are each
                checked against values.
    list_len -- optional required list length: either a fixed integer, or a
                (length, name) pair whose name is used in error messages.

    Raises ValueError for an invalid value or list length, TypeError for an
    invalid type.
    """
    if list_ok and isinstance(value, list):
        if list_len is not None:
            if isinstance(list_len, _all_int):
                ok = (len(value) == list_len)
                len_name = '%d' % list_len
            else:
                ok = (len(value) == list_len[0])
                len_name = list_len[1]
            if not ok:
                raise ValueError("list "+name+" length must equal "+len_name)
        for v in value:
            try:
                # make sure v is not a list too
                verify_values(name, v, values)
            except ValueError:
                raise ValueError(name+" must be in " + str(values) + \
                                 " or a list of these")
            except TypeError:
                raise TypeError(name+" must be in " + str(values) + \
                                " or a list of these")
    elif value not in values:
        # BUG FIX: previously a scalar with list_ok=True always raised
        # TypeError even when it was an allowed value; scalars are now
        # accepted regardless of list_ok, consistent with verify_nonneg
        # and verify_pos.
        raise ValueError(name+" must be in " + str(values))
1362 1363
def verify_intbool(name, value, list_ok=False, list_len=None):
    """Check that value is 0, 1 or a boolean.

    With list_ok True, value must instead be a list of such values.
    list_len optionally constrains the list length: either a fixed
    integer, or a (length, name) pair whose name is used in error
    messages. Raises ValueError or TypeError on failure.
    """
    list_msg = name+" must be 0, 1, or a boolean, or a list of these"
    if list_ok:
        if not isinstance(value, list):
            raise TypeError(list_msg)
        if list_len is not None:
            if isinstance(list_len, _all_int):
                required = list_len
                len_name = '%d' % list_len
            else:
                required = list_len[0]
                len_name = list_len[1]
            if len(value) != required:
                raise ValueError("list "+name+" length must equal "+len_name)
        for entry in value:
            try:
                # entries must be scalars, not nested lists
                verify_intbool(name, entry)
            except ValueError:
                raise ValueError(list_msg)
            except TypeError:
                raise TypeError(list_msg)
    elif isinstance(value, _all_int):
        if value not in [0, 1]:
            raise ValueError("integer "+name+" must be 0 or 1")
    elif not isinstance(value, bool):
        raise TypeError(name+" must be 0, 1 or a boolean")
1400 1401
def verify_nonneg(name, value, types, list_ok=False, list_len=None):
    """Check that value is a non-negative number of one of the given types.

    With list_ok True, a list of such values is also acceptable.
    list_len optionally constrains the list length: either a fixed
    integer, or a (length, name) pair whose name is used in error
    messages. Raises ValueError or TypeError on failure.
    """
    if isinstance(value, types):
        if value < 0:
            raise ValueError(name+" must be non-negative")
        return
    if not list_ok or not isinstance(value, list):
        if list_ok:
            raise TypeError(name+" must be "+_verify_type_names[types]+ \
                            " and non-negative, or a list of these")
        raise TypeError(name+" must be "+_verify_type_names[types]+ \
                        " and non-negative")
    if list_len is not None:
        if isinstance(list_len, _all_int):
            required = list_len
            len_name = '%d' % list_len
        else:
            required = list_len[0]
            len_name = list_len[1]
        if len(value) != required:
            raise ValueError("list "+name+" length must equal "+len_name)
    for entry in value:
        try:
            # entries must be scalars, not nested lists
            verify_nonneg(name, entry, types)
        except ValueError:
            raise ValueError(name+" must be "+_verify_type_names[types]+ \
                             " and non-negative, or a list of these")
        except TypeError:
            raise TypeError(name+" must be "+_verify_type_names[types]+ \
                            " and non-negative, or a list of these")
1439 1440
def verify_pos(name, value, types, list_ok=False, list_len=None):
    """Check that value is a strictly positive number of one of the given
    types.

    name     -- variable name used in error messages.
    value    -- the value to validate.
    types    -- type or tuple of types accepted by isinstance.
    list_ok  -- if True, value may instead be a list of such numbers.
    list_len -- optional required list length: either a fixed integer, or a
                (length, name) pair whose name is used in error messages.

    Raises ValueError for a non-positive value or bad list length,
    TypeError for an invalid type.
    """
    if isinstance(value, types):
        if value <= 0:
            raise ValueError(name+" must be positive")
    elif list_ok:
        if isinstance(value, list):
            if list_len is not None:
                if isinstance(list_len, _all_int):
                    ok = (len(value) == list_len)
                    len_name = '%d' % list_len
                else:
                    ok = (len(value) == list_len[0])
                    len_name = list_len[1]
                if not ok:
                    raise ValueError("list "+name+" length must equal "+len_name)
            for v in value:
                try:
                    # make sure v is not a list too.
                    # BUG FIX: recurse with verify_pos -- this previously
                    # called verify_nonneg, so zero entries inside a list
                    # incorrectly passed the positivity check.
                    verify_pos(name, v, types)
                except ValueError:
                    raise ValueError(name+" must be "+_verify_type_names[types]+ \
                                 " and positive, or a list of these")
                except TypeError:
                    raise TypeError(name+" must be "+_verify_type_names[types]+ \
                                 " and positive, or a list of these")
        else:
            raise TypeError(name+" must be "+_verify_type_names[types]+ \
                            " and positive, or a list of these")
    else:
        raise TypeError(name+" must be "+_verify_type_names[types]+ \
                        " and positive")
1478 1479
def array_bounds_check(a, bounds, dirn=1):
    """Internal utility function to test a 1D array for staying within given
    bounds (min val, max val).

    Returns the largest index +1 if the array is within bounds, otherwise the
    first offending index, where 'first' is the earliest in a if direction
    dirn=1, or the latest if dirn=-1."""
    # NOTE(review): a<bounds and a>bounds rely on numpy broadcasting between
    # a and the bounds pair -- confirm the expected shapes of a and bounds
    # against callers.
    if dirn == 1:
        # scan forwards: sentinel "all OK" result is one past the last index
        OK_ix = len(a)
        alo = asarray(a<bounds, int)
        ahi = asarray(a>bounds, int)
        alo_first = alo.argmax()   # argmax finds the first 1 (first violation)
        ahi_first = ahi.argmax()
        test_val = 0
        compare = min              # earliest offending index wins
    elif dirn == -1:
        # scan backwards: sentinel "all OK" result is -1
        OK_ix = -1
        alo = 1 - asarray(a<bounds, int)
        ahi = 1 - asarray(a>bounds, int)
        alo_first = alo.argmin()   # inverted mask: argmin finds a violation
        ahi_first = ahi.argmin()
        test_val = 1
        compare = max              # latest offending index wins
    else:
        raise ValueError("Invalid direction")
    first_fail_ix = OK_ix
    if alo[alo_first] != test_val:
        # an element was below lower bound
        first_fail_ix = alo_first
    if ahi[ahi_first] != test_val:
        # an element was above upper bound
        if first_fail_ix == OK_ix:
            first_fail_ix = ahi_first
        else:
            # both kinds of violation occurred: pick per scan direction
            first_fail_ix = compare(first_fail_ix, ahi_first)
    return first_fail_ix
1516 1517
def linearInterp(y0, ygoal, y1, x0, x1):
    """Internal utility: return the x at which the line through
    (x0, y0) and (x1, y1) attains the value ygoal."""
    weight_hi = ygoal - y0
    weight_lo = y1 - ygoal
    return (x1 * weight_hi + x0 * weight_lo) / (y1 - y0)
1522 1523
def makeUniqueFn(fstr, tdigits=0, idstr=None):
    """Add unique ID to function names.

    Used when functions are executed in global namespace to avoid name
    clashes, and need to be distinguished when DS objects are copied.

    fstr    -- source string beginning "def <name>(..."
    tdigits -- if > 0, append a session timestamp of that many digits
               to the function name.
    idstr   -- optional extra identifier inserted before the timestamp.

    Returns (new definition string, new function name).
    """
    # check for syntax errors
    try:
        code = compile(fstr, 'test', 'exec')
    except:
        print " Cannot make unique function because of a syntax (or other) error " \
              "in supplied code:\n"
        print fstr
        raise
    bracepos = fstr.index("(")
    if idstr is None:
        idstr_insert = ""
    else:
        idstr_insert = "_" + idstr
    if tdigits > 0:
        # fstr[4:bracepos] assumes the string starts with "def " (4 chars)
        fname = fstr[4:bracepos] + idstr_insert + "_" + timestamp(tdigits)
    else:
        # NOTE(review): idstr_insert is dropped in this branch, so idstr is
        # silently ignored when tdigits == 0 -- confirm this is intended.
        fname = fstr[4:bracepos]
    fstr_new = "def " + fname + fstr[bracepos:]
    return (fstr_new, fname)
1548 1549
def timestamp(tdigits=8):
    """Return a unique timestamp string for the session. useful for ensuring
    unique function identifiers, etc.
    """
    # Digits of the per-session processor timer, with the decimal point and
    # any minus sign stripped.
    # NOTE(review): the slice keeps tdigits+1 characters, not tdigits --
    # confirm whether the off-by-one is intentional.
    # NOTE(review): time.clock is deprecated and removed in Python 3.8.
    return str(time.clock()).replace(".", "").replace("-","")[:tdigits+1]
1555 1556
def isUniqueSeq(objlist):
    """Check that the list contains each item only once."""
    if not objlist:
        # an empty list trivially has no repeats
        return True
    return all([objlist.count(item) == 1 for item in objlist])
1563
def makeSeqUnique(seq, asarray=False):
    """Return a 1D sequence that only contains the unique values in seq,
    keeping first-occurrence order. When asarray is True the result is
    returned as a numpy array instead of a list.

    Adapted from code by Raymond Hettinger, 2002"""
    seen = {}
    uniques = [seen.setdefault(e, e) for e in seq if e not in seen]
    if asarray:
        return array(uniques)
    else:
        return uniques
1572 1573
def object2str(x):
    """Convert occurrences of types / classes,
    to pretty-printable strings."""
    # Recursively renders lists, tuples and dicts; plain strings are returned
    # verbatim (dropping the quotes repr would add); everything else is
    # repr'd. NOTE: Python 2 only (types.InstanceType/TypeType, iteritems).
    try:
        if type(x) in [types.InstanceType, types.TypeType]:
            # old-style instance or class: show its class name
            return className(x, True)
        elif isinstance(x, list):
            # search through any iterable parts (that aren't strings)
            rx = "["
            if len(x)>0:
                for o in x:
                    rx += object2str(o) + ", "
                # trim the trailing ", "
                return rx[:-2]+"]"
            else:
                return rx+"]"
        elif isinstance(x, tuple):
            rx = "("
            if len(x)>0:
                for o in x:
                    rx += object2str(o) + ", "
                return rx[:-2]+")"
            else:
                return rx+")"
        elif isinstance(x, dict):
            rx = "{"
            if len(x)>0:
                for k, o in x.iteritems():
                    rx += object2str(k) + ": " + object2str(o) + ", "
                return rx[:-2]+"}"
            else:
                return rx+"}"
        elif isinstance(x, str):
            # this removes extraneous single quotes around dict keys, for instance
            return x
        else:
            return repr(x)
    except:
        raise TypeError("object2str cannot format this object type")
1612 1613 1614 # The class types can show different roots when they originate from 1615 # different parts of the PyDSTool package -- it might be a bug. 1616 # e.g. baseClass might be Generator.Generator, but here this type will be 1617 # <class 'PyDSTool.Generator.baseclasses.Generator'> 1618 # and input.__class__ will boil down to 1619 # <class 'Generator.baseclasses.Generator'> 1620 # even though these classes are identical (constructed from the same class 1621 # in the same module!)
def compareBaseClass(input, baseClass):
    """Test whether baseClass appears among the immediate base classes
    of input.

    input may be a class or a class instance representing that class;
    baseClass may be a class or a string name of a class. The comparison
    is made using class names only."""
    if isinstance(baseClass, str):
        target = baseClass
    elif isinstance(baseClass, type):
        target = baseClass.__name__
    else:
        raise TypeError("Must pass either a class or a class name (string)")
    if isinstance(input, type):
        bases = input.__bases__
    else:
        try:
            bases = input.__class__.__bases__
        except AttributeError:
            # not the kind of baseClass PyDSTool is interested in
            # (e.g. an exception type)
            return False
    return any([c.__name__ == target for c in bases])
1643 1644
def compareClassAndBases(input, arg):
    """Test whether input is an instance or subclass of arg, where arg
    can be a single class or a sequence of classes."""
    try:
        # sequence case: succeed if any member matches
        return any([compareClassAndBases(input, a) for a in arg])
    except TypeError:
        pass
    try:
        if isinstance(input, type):
            # input is a class
            return issubclass(input, arg)
        # input is an instance
        return isinstance(input, arg)
    except TypeError:
        raise TypeError("Invalid class(es) provided: input %s vs. %s" \
              %(str(input)+" of type "+className(input),className(arg,True)))
1661 1662
def getSuperClasses(obj, limitClasses=None):
    """Return the string names of all super classes of obj, walking the
    class hierarchy breadth-first and stopping at any class named in
    limitClasses (a single class or list of classes; 'object' always
    acts as a stopping point)."""
    if limitClasses == None:
        stop_names = ['object']
    elif isinstance(limitClasses, list):
        stop_names = [className(lc) for lc in limitClasses]
    else:
        # a single class was supplied
        stop_names = [className(limitClasses)]
    # ensure the "object" safety net is present
    if 'object' not in stop_names:
        stop_names.append('object')
    frontier = [obj.__class__]
    sclasses = [className(frontier[0])]
    done = sclasses[0] in stop_names
    # hard depth cap of 10 mirrors the original loop guard
    for _depth in range(10):
        if done:
            break
        next_frontier = []
        for klass in frontier:
            next_frontier.extend(list(klass.__bases__))
        frontier = next_frontier
        for base in frontier:
            base_name = className(base)
            done = base_name in stop_names
            if done:
                break
            sclasses.append(base_name)
    return sclasses
1694 1695
def className(obj, addPrefix=False):
    """Return a human-readable name for obj's class.

    If obj is already a string it is returned as-is; classes and modules
    yield their own names (optionally prefixed with "Class "/"Module "
    when addPrefix is True); any other object yields the name of its
    class."""
    prefix = ""
    if isinstance(obj, str):
        # already a name; its kind is unknown so no prefix is added
        class_str = obj
    elif isinstance(obj, type):
        class_str = obj.__name__
        if addPrefix:
            prefix = "Class "
    elif isinstance(obj, types.ModuleType):
        class_str = obj.__name__
        if addPrefix:
            prefix = "Module "
    else:
        try:
            class_str = obj.__class__.__name__
        except AttributeError:
            # fall back on the raw type representation
            class_str = str(type(obj))
    return prefix + class_str
1721 1722 1723 # little utility function to wrap value as a singleton list
def listid(val):
    """Wrap val as a singleton list."""
    return [val]
1726 1727 1728 # the identity function
def idfn(val):
    """The identity function: returns a (shallow) copy of val."""
    duplicate = copy(val)
    return duplicate
1731 1732 1733 # utility function representing a "none" function
def noneFn(x):
    """A "none" function: ignores its argument and returns None."""
    return None
1736 1737 1738 # returns the mapping from the entries in an array or list to their indices
def makeArrayIxMap(a):
    """Map each entry of the array or list a to its positional index.
    Later duplicates overwrite earlier ones."""
    positions = range(len(a))
    return dict(zip(a, positions))
1741 1742 1743 # invert an index mapping or other form of mapping
def invertMap(themap):
    """invert an index mapping or other form of mapping.

    If argument is a dict or sequence type, returns a dictionary,
    but if argument is a parseUtils.symbolMapClass then that type is
    returned.

    For a dict, values become keys; list/tuple values are expanded so
    that each member maps back to the original key. For a sequence,
    entries map to their position index.
    """
    if isinstance(themap, dict):
        try:
            # FIX: replaced Python-2-only tuple-parameter lambda
            # (dict(map(lambda (k,v): (v,k), ...))) with a comprehension;
            # behavior is unchanged.
            return dict([(v, k) for k, v in themap.items()])
        except TypeError:
            # e.g., list objects are unhashable
            # try it the slow way for this case
            result = {}
            for k, v in themap.items():
                if isinstance(v, (list, tuple)):
                    for val in v:
                        result[val] = k
                else:
                    result[v] = k
            return result
    elif isinstance(themap, (list, tuple)):
        # input domain is the position index
        return dict(zip(themap, range(len(themap))))
    elif isinstance(themap, ndarray):
        # input domain is the position index
        return dict(zip(themap.tolist(), range(len(themap))))
    elif hasattr(themap, 'inverse'):
        # symbolMapClass type
        return themap.inverse()
    else:
        raise TypeError("Unsupported type for map")
1775 1776
def isincreasing(theseq, withVal=False):
    """
    Test whether a sequence is in strictly increasing order. When withVal
    is True, also return the first two offending values (or None, None if
    the sequence is increasing).
    """
    # This implementation is fastest for the common internal case: the
    # sequence is already an array and is indeed increasing.
    try:
        theseq[0]
    except IndexError:
        raise ValueError("Problem with sequence passed to "
                         "function `isincreasing` -- is it empty?")
    arr = array(theseq)
    rising = arr[1:] > arr[:-1]
    if not withVal:
        return all(rising)
    if all(rising):
        return True, None, None
    first_bad = rising.tolist().index(False)
    return False, theseq[first_bad], theseq[first_bad+1]
1802 1803
def ismonotonic(theseq, withVal=False):
    """
    Test whether a sequence is in strictly increasing or strictly
    decreasing order. When withVal is True, also return the first two
    offending values (taken from the increasing-order check) for a
    non-monotone sequence.
    """
    if not withVal:
        return isincreasing(theseq) or isincreasing(theseq[::-1])
    increasing, bad1, bad2 = isincreasing(theseq, True)
    decreasing = isincreasing(theseq[::-1], False)
    if increasing or decreasing:
        return True, None, None
    return False, bad1, bad2
1821 1822
def extent(data):
    """Returns a pair of the min and max values of a dataset, or just a
    numeric type if these are equal. (Ignores NaNs.)
    """
    lo = npy.nanmin(data)
    hi = npy.nanmax(data)
    if lo != hi:
        return [lo, hi]
    return lo
1833
def uniquePoints(ar):
    """For an n by m array input, return only the unique rows (points),
    keeping first-occurrence order."""
    seen = set()
    kept = []
    for row in ar:
        key = tuple(row)
        if key in seen:
            continue
        seen.add(key)
        kept.append(key)
    return array(kept)
1844 1845
def sortedDictValues(d, onlykeys=None, reverse=False):
    """Return list of values from a dictionary in order of sorted key list.

    Adapted from original function by Alex Martelli:
     added filtering of keys.
    """
    # NOTE: Python 2 semantics -- d.keys() yields a list that is sorted in
    # place, and map() returns a list.
    if onlykeys is None:
        keys = d.keys()
    else:
        # restrict to keys present in both d and onlykeys
        keys = intersect(d.keys(), onlykeys)
    keys.sort()
    if reverse:
        keys.reverse()
    return map(d.get, keys)
1860
def sortedDictKeys(d, onlykeys=None, reverse=False):
    """Return sorted list of keys from a dictionary.

    Adapted from original function by Alex Martelli:
     added filtering of keys."""
    # NOTE: Python 2 semantics -- d.keys() yields a list that is sorted
    # in place.
    if onlykeys is None:
        keys = d.keys()
    else:
        # restrict to keys present in both d and onlykeys
        keys = intersect(d.keys(), onlykeys)
    keys.sort()
    if reverse:
        keys.reverse()
    return keys
1874
def sortedDictLists(d, byvalue=True, onlykeys=None, reverse=False):
    """Return (key list, value list) pair from a dictionary,
    sorted by value (default) or key, optionally restricted to onlykeys
    and optionally in reverse order.
    Adapted from an original function by Duncan Booth.
    """
    if onlykeys is None:
        onlykeys = d.keys()
    if byvalue:
        pairs = [(val, key) for (key, val) in d.items() if key in onlykeys]
        pairs.sort()
        if reverse:
            pairs.reverse()
        sorted_keys = [key for (val, key) in pairs]
        sorted_vals = [val for (val, key) in pairs]
    else:
        # sort by key instead
        pairs = [(key, val) for (key, val) in d.items() if key in onlykeys]
        pairs.sort()
        if reverse:
            pairs.reverse()
        sorted_keys = [key for (key, val) in pairs]
        sorted_vals = [val for (key, val) in pairs]
    return (sorted_keys, sorted_vals)
1898
def sortedDictItems(d, byvalue=True, onlykeys=None, reverse=False):
    """Return list of (key, value) pairs of a dictionary,
    sorted by value (default) or key, optionally restricted to onlykeys.
    Adapted from an original function by Duncan Booth.
    """
    sorted_keys, sorted_vals = sortedDictLists(d, byvalue, onlykeys, reverse)
    return zip(sorted_keys, sorted_vals)
1906 1907 # ---------------------------------------------------------------------- 1908 1909 ## private versions of these utils (cannot import them from utils!) 1910 1911 # find intersection of two lists, sequences, etc.
def intersect(a, b):
    """Return the members of a that also occur in b, keeping a's order.
    filter() is retained so that, under Python 2, string and tuple inputs
    keep their own sequence type."""
    return filter(lambda member: member in b, a)
1914 1915 1916 # find remainder of two lists, sequences, etc., after intersection
def remain(a, b):
    """Return the members of a that do NOT occur in b, keeping a's order.
    filter() is retained so that, under Python 2, string and tuple inputs
    keep their own sequence type."""
    return filter(lambda member: member not in b, a)
1919 1920 1921 # ---------------------------------------------------------------------- 1922 1923 # The Utility class may be abandoned in future versions.
class Utility(object):
    """
    Abstract base class for tools that manipulate and analyze dynamical
    systems -- e.g. continuation, dimension reduction, or parameter
    estimation tools.

    Robert Clewley, March 2005.
    """
    pass
1934 1935 1936 # -------------------------------------------------------------------- 1937 # This section adapted from scipy.interpolate 1938 # -------------------------------------------------------------------- 1939 1940
class interpclass(object):
    """Abstract base class for interpolators."""
    # default interpolation axis. DO NOT CHANGE OR CODE WILL BREAK.
    interp_axis = -1
class interp0d(interpclass):
    """Design of this class based on SciPy's interp1d"""

    def __init__(self, x, y, axis=-1, makecopy=0, bounds_error=1,
                 fill_value=None):
        """Initialize a piecewise-constant interpolation class

        Description:
          x and y are arrays of values used to approximate some function f:
            y = f(x)
          This class returns a function whose call method uses piecewise-
          constant interpolation to find the value of new points.

        Inputs:
            x -- a 1d array of monotonically increasing real values.
                 x cannot include duplicate values. (otherwise f is
                 overspecified)
            y -- an nd array of real values. y's length along the
                 interpolation axis must be equal to the length
                 of x.
            axis -- specifies the axis of y along which to
                    interpolate. Interpolation defaults to the last
                    axis of y. (default: -1)
            makecopy -- If 1, the class makes internal copies of x and y.
                    If 0, references to x and y are used. The default
                    is NOT to copy. (default: 0)
            bounds_error -- If 1, an error is thrown any time interpolation
                    is attempted on a value outside of the range
                    of x (where extrapolation is necessary).
                    If 0, out of bounds values are assigned the
                    NaN (#INF) value. By default, an error is
                    raised, although this is prone to change.
                    (default: 1)
        """
        self.datapoints = (array(x, float), array(y, float))  # RHC -- for access from PyDSTool
        self.type = float  # RHC -- for access from PyDSTool
        self.axis = axis
        self.makecopy = makecopy  # RHC -- renamed from copy to avoid nameclash
        self.bounds_error = bounds_error
        if fill_value is None:
            self.fill_value = NaN  # RHC -- was: array(0.0) / array(0.0)
        else:
            self.fill_value = fill_value

        # Check that both x and y are at least 1 dimensional.
        if len(shape(x)) == 0 or len(shape(y)) == 0:
            raise ValueError("x and y arrays must have at least one dimension.")
        # make a "view" of the y array that is rotated to the
        # interpolation axis.
        oriented_x = x
        oriented_y = swapaxes(y,self.interp_axis,axis)
        interp_axis = self.interp_axis
        len_x,len_y = shape(oriented_x)[interp_axis], \
                        shape(oriented_y)[interp_axis]
        if len_x != len_y:
            raise ValueError("x and y arrays must be equal in length along "
                             "interpolation axis.")
        if len_x < 2 or len_y < 2:
            raise ValueError("x and y arrays must have more than 1 entry")
        self.x = array(oriented_x,copy=self.makecopy)
        self.y = array(oriented_y,copy=self.makecopy)

    def __call__(self,x_new):
        """Find piecewise-constant interpolated y_new = <name>(x_new).

        Inputs:
          x_new -- New independent variables.

        Outputs:
          y_new -- Piecewise-constant interpolated values corresponding to x_new.
        """
        # 1. Handle values in x_new that are outside of x. Throw error,
        #    or return a list of mask array indicating the outofbounds values.
        #    The behavior is set by the bounds_error variable.
        ## RHC -- was  x_new = atleast_1d(x_new)
        x_new_1d = atleast_1d(x_new)
        out_of_bounds = self._check_bounds(x_new_1d)
        # 2. Find where in the orignal data, the values to interpolate
        #    would be inserted.
        #    Note: If x_new[n] = x[m], then m is returned by searchsorted.
        x_new_indices = searchsorted(self.x,x_new_1d)
        # 3. Clip x_new_indices so that they are within the range of
        #    self.x indices and at least 1. Removes mis-interpolation
        #    of x_new[n] = x[0]
        x_new_indices = clip(x_new_indices,1,len(self.x)-1).astype(int)
        # 4. Calculate the region that each x_new value falls in.
        lo = x_new_indices - 1; hi = x_new_indices

        # !! take() should default to the last axis (IMHO) and remove
        # !! the extra argument.
        # 5. Calculate the actual value for each entry in x_new.
        #    NOTE(review): the result is the average of the two neighboring
        #    samples, i.e. a midpoint rule rather than a strict zero-order
        #    hold -- confirm this is the intended "piecewise constant" rule.
        y_lo = take(self.y,lo,axis=self.interp_axis)
        y_hi = take(self.y,hi,axis=self.interp_axis)
        y_new = (y_lo+y_hi)/2.
        # 6. Fill any values that were out of bounds with NaN
        # !! Need to think about how to do this efficiently for
        # !! mutli-dimensional Cases.
        yshape = y_new.shape
        y_new = y_new.ravel()
        new_shape = list(yshape)
        new_shape[self.interp_axis] = 1
        sec_shape = [1]*len(new_shape)
        sec_shape[self.interp_axis] = len(out_of_bounds)
        out_of_bounds.shape = sec_shape
        new_out = ones(new_shape)*out_of_bounds
        putmask(y_new, new_out.ravel(), self.fill_value)
        y_new.shape = yshape
        # Rotate the values of y_new back so that they correspond to the
        # correct x_new values.
        result = swapaxes(y_new,self.interp_axis,self.axis)
        try:
            # scalar input (no len) gets a scalar result
            len(x_new)
            return result
        except TypeError:
            return result[0]
        # NOTE(review): unreachable -- both paths above return first
        return result

    def _check_bounds(self,x_new):
        # If self.bounds_error = 1, we raise an error if any x_new values
        # fall outside the range of x. Otherwise, we return an array indicating
        # which values are outside the boundary region.
        # !! Needs some work for multi-dimensional x !!
        below_bounds = less(x_new,self.x[0])
        above_bounds = greater(x_new,self.x[-1])
        # Note: sometrue has been redefined to handle length 0 arrays
        # !! Could provide more information about which values are out of bounds
        # RHC -- Changed these ValueErrors to PyDSTool_BoundsErrors
        if self.bounds_error and any(sometrue(below_bounds)):
##            print "Input:", x_new
##            print "Bound:", self.x[0]
##            print "Difference input - bound:", x_new-self.x[0]
            raise PyDSTool_BoundsError(" A value in x_new is below the"
                               " interpolation range.")
        if self.bounds_error and any(sometrue(above_bounds)):
##            print "Input:", x_new
##            print "Bound:", self.x[-1]
##            print "Difference input - bound:", x_new-self.x[-1]
            raise PyDSTool_BoundsError(" A value in x_new is above the"
                               " interpolation range.")
        # !! Should we emit a warning if some values are out of bounds.
        # !! matlab does not.
        out_of_bounds = logical_or(below_bounds,above_bounds)
        return out_of_bounds
2092 2093 2094 # RHC added
2095 - def __getstate__(self):
2096 d = copy(self.__dict__) 2097 # remove reference to Cfunc self.type 2098 d['type'] = _num_type2name[self.type] 2099 return d
2100 2101 # RHC added
2102 - def __setstate__(self, state):
2103 self.__dict__.update(state) 2104 # reinstate Cfunc self.type 2105 self.type = _num_name2type[self.type]
2106 2107 2108
class interp1d(interpclass):    # RHC -- made this a new-style Python class
    """Piecewise-linear interpolation of sampled data.

    interp1d(x, y) returns a callable object f such that f(x_new)
    linearly interpolates the data y = f(x) along the chosen axis of y.
    """
    def __init__(self, x, y, kind='linear', axis=-1,
                 makecopy=0, bounds_error=1, fill_value=None):
        """Initialize a 1d piecewise-linear interpolation class

        Description:
          x and y are arrays of values used to approximate some function f:
            y = f(x)
          This class returns a function whose call method uses linear
          interpolation to find the value of new points.

        Inputs:
            x -- a 1d array of monotonically increasing real values.
                 x cannot include duplicate values. (otherwise f is
                 overspecified)
            y -- an nd array of real values.  y's length along the
                 interpolation axis must be equal to the length
                 of x.
            kind -- specify the kind of interpolation: 'nearest', 'linear',
                    'cubic', or 'spline'.  (Only 'linear' is implemented;
                    anything else raises NotImplementedError.)
            axis -- specifies the axis of y along which to
                    interpolate.  Interpolation defaults to the last
                    axis of y.  (default: -1)
            makecopy -- If 1, the class makes internal copies of x and y.
                    If 0, references to x and y are used.  The default
                    is NOT to copy.  (default: 0)
            bounds_error -- If 1, an error is thrown any time interpolation
                    is attempted on a value outside of the range
                    of x (where extrapolation is necessary).
                    If 0, out of bounds values are assigned the
                    fill_value.  (default: 1)
            fill_value -- value assigned to out-of-bounds points when
                    bounds_error is 0.  (default: NaN)
        """
        self.datapoints = (array(x, float), array(y, float))  # RHC -- for access from PyDSTool
        self.type = float        # RHC -- for access from PyDSTool
        self.axis = axis
        self.makecopy = makecopy   # RHC -- renamed from copy to avoid nameclash
        self.bounds_error = bounds_error
        if fill_value is None:
            self.fill_value = NaN   # RHC -- was: array(0.0) / array(0.0)
        else:
            self.fill_value = fill_value

        if kind != 'linear':
            raise NotImplementedError("Only linear supported for now. "
                                      "Use fitpack routines for other types.")

        # Check that both x and y are at least 1 dimensional.
        if len(shape(x)) == 0 or len(shape(y)) == 0:
            raise ValueError("x and y arrays must have at least one dimension.")
        # make a "view" of the y array that is rotated to the
        # interpolation axis.  (self.interp_axis is assumed to be supplied
        # by the interpclass base class -- TODO confirm)
        oriented_x = x
        oriented_y = swapaxes(y, self.interp_axis, axis)
        interp_axis = self.interp_axis
        len_x, len_y = shape(oriented_x)[interp_axis], \
                       shape(oriented_y)[interp_axis]
        if len_x != len_y:
            raise ValueError("x and y arrays must be equal in length along "
                             "interpolation axis.")
        if len_x < 2 or len_y < 2:
            raise ValueError("x and y arrays must have more than 1 entry")
        self.x = array(oriented_x, copy=self.makecopy)
        self.y = array(oriented_y, copy=self.makecopy)

    def __call__(self, x_new):
        """Find linearly interpolated y_new = <name>(x_new).

        Inputs:
          x_new -- New independent variables (scalar or array-like).

        Outputs:
          y_new -- Linearly interpolated values corresponding to x_new.
            Out-of-bounds entries are set to self.fill_value unless bounds
            checking raises first.  Scalar input yields a scalar result.
        """
        # 1. Handle values in x_new that are outside of x.  Either raises
        #    (bounds_error set) or returns a mask of out-of-bounds entries.
        x_new_1d = atleast_1d(x_new)
        out_of_bounds = self._check_bounds(x_new_1d)
        # 2. Find where in the original data the values to interpolate
        #    would be inserted.
        #    Note: if x_new[n] == x[m], then m is returned by searchsorted.
        x_new_indices = searchsorted(self.x, x_new_1d)
        # 3. Clip x_new_indices so that they are within the range of
        #    self.x indices and at least 1.  Removes mis-interpolation
        #    of x_new[n] = x[0].
        x_new_indices = clip(x_new_indices, 1, len(self.x)-1).astype(int)
        # 4. Calculate the slope of the region that each x_new value falls in.
        lo = x_new_indices - 1
        hi = x_new_indices

        x_lo = take(self.x, lo, axis=self.interp_axis)
        x_hi = take(self.x, hi, axis=self.interp_axis)
        y_lo = take(self.y, lo, axis=self.interp_axis)
        y_hi = take(self.y, hi, axis=self.interp_axis)
        slope = (y_hi-y_lo)/(x_hi-x_lo)
        # 5. Calculate the actual value for each entry in x_new.
        y_new = slope*(x_new_1d-x_lo) + y_lo
        # 6. Fill any values that were out of bounds with fill_value.
        # !! Need to think about how to do this efficiently for
        # !! multi-dimensional cases.
        yshape = y_new.shape
        y_new = y_new.ravel()
        new_shape = list(yshape)
        new_shape[self.interp_axis] = 1
        sec_shape = [1]*len(new_shape)
        sec_shape[self.interp_axis] = len(out_of_bounds)
        out_of_bounds.shape = sec_shape
        new_out = ones(new_shape)*out_of_bounds
        putmask(y_new, new_out.ravel(), self.fill_value)
        y_new.shape = yshape
        # Rotate the values of y_new back so that they correspond to the
        # correct x_new values.
        result = swapaxes(y_new, self.interp_axis, self.axis)
        # Scalar input (no len()) gets a scalar result.  The original code
        # had an unreachable 'return result' after this try/except (both
        # branches already returned); it has been removed.
        try:
            len(x_new)
        except TypeError:
            return result[0]
        return result

    def _check_bounds(self, x_new):
        """Return a boolean mask of x_new entries outside [x[0], x[-1]].

        If self.bounds_error is true, raise PyDSTool_BoundsError as soon
        as any entry falls outside the interpolation range; otherwise the
        mask is returned for the caller to apply fill_value.
        """
        # !! Needs some work for multi-dimensional x !!
        below_bounds = less(x_new, self.x[0])
        above_bounds = greater(x_new, self.x[-1])
        # Note: sometrue has been redefined to handle length 0 arrays
        # !! Could provide more information about which values are out of bounds
        # RHC -- Changed these ValueErrors to PyDSTool_BoundsErrors
        if self.bounds_error and any(sometrue(below_bounds)):
            raise PyDSTool_BoundsError("A value in x_new is below the"
                                       " interpolation range.")
        if self.bounds_error and any(sometrue(above_bounds)):
            raise PyDSTool_BoundsError("A value in x_new is above the"
                                       " interpolation range.")
        # !! Should we emit a warning if some values are out of bounds?
        # !! matlab does not.
        out_of_bounds = logical_or(below_bounds, above_bounds)
        return out_of_bounds

    # RHC added
    def __getstate__(self):
        """Return a picklable snapshot of the instance dict, replacing the
        unpicklable Cfunc in self.type by its registered name."""
        d = copy(self.__dict__)
        d['type'] = _num_type2name[self.type]
        return d

    # RHC added
    def __setstate__(self, state):
        """Restore pickled state, mapping the stored type name back to its
        Cfunc object."""
        self.__dict__.update(state)
        self.type = _num_name2type[self.type]
2275 2276 2277 # The following interpolation functions were written and (c) Anne 2278 # Archibald. 2279
class KroghInterpolator(object):
    """The interpolating polynomial for a set of points

    Constructs a polynomial that passes through a given set of points,
    optionally with specified derivatives at those points.
    Allows evaluation of the polynomial and all its derivatives.
    For reasons of numerical stability, this function does not compute
    the coefficients of the polynomial, although they can be obtained
    by evaluating all the derivatives.

    Be aware that the algorithms implemented here are not necessarily
    the most numerically stable known. Moreover, even in a world of
    exact computation, unless the x coordinates are chosen very
    carefully - Chebyshev zeros (e.g. cos(i*pi/n)) are a good choice -
    polynomial interpolation itself is a very ill-conditioned process
    due to the Runge phenomenon. In general, even with well-chosen
    x values, degrees higher than about thirty cause problems with
    numerical instability in this code.

    Based on Krogh 1970, "Efficient Algorithms for Polynomial Interpolation
    and Numerical Differentiation"
    """
    def __init__(self, xi, yi):
        """Construct an interpolator passing through the specified points

        The polynomial passes through all the pairs (xi,yi). One may additionally
        specify a number of derivatives at each point xi; this is done by
        repeating the value xi and specifying the derivatives as successive
        yi values.

        Parameters
        ----------
        xi : array-like, length N
            known x-coordinates
        yi : array-like, N by R
            known y-coordinates, interpreted as vectors of length R,
            or scalars if R=1

        Example
        -------
        To produce a polynomial that is zero at 0 and 1 and has
        derivative 2 at 0, call

        >>> KroghInterpolator([0,0,1],[0,2,0])
        """
        # stdlib factorial: exact for integers, and unlike the old
        # scipy.factorial it is still available in current scipy eras
        from math import factorial
        self.xi = npy.asarray(xi)
        self.yi = npy.asarray(yi)
        if len(self.yi.shape) == 1:
            self.vector_valued = False
            self.yi = self.yi[:, npy.newaxis]
        elif len(self.yi.shape) > 2:
            raise ValueError("y coordinates must be either scalars or vectors")
        else:
            self.vector_valued = True

        n = len(xi)
        self.n = n
        nn, r = self.yi.shape
        if nn != n:
            raise ValueError("%d x values provided and %d y values; must be equal" % (n, nn))
        self.r = r

        # Newton-form divided-difference table (with confluent abscissae
        # for repeated xi encoding derivatives)
        c = npy.zeros((n+1, r))
        c[0] = yi[0]
        Vk = npy.zeros((n, r))
        for k in range(1, n):
            # s = number of times xi[k] has already appeared (multiplicity)
            s = 0
            while s <= k and xi[k-s] == xi[k]:
                s += 1
            s -= 1
            Vk[0] = yi[k]/float(factorial(s))
            for i in range(k-s):
                assert xi[i] != xi[k]
                if s == 0:
                    Vk[i+1] = (c[i]-Vk[i])/(xi[i]-xi[k])
                else:
                    Vk[i+1] = (Vk[i+1]-Vk[i])/(xi[i]-xi[k])
            c[k] = Vk[k-s]
        self.c = c

    def __call__(self, x):
        """Evaluate the polynomial at the point x

        Parameters
        ----------
        x : scalar or array-like of length N

        Returns
        -------
        y : scalar, array of length R, array of length N, or array of length N by R
            If x is a scalar, returns either a vector or a scalar depending on
            whether the interpolator is vector-valued or scalar-valued.
            If x is a vector, returns a vector of values.
        """
        if npy.isscalar(x):
            scalar = True
            m = 1
        else:
            scalar = False
            m = len(x)
        x = npy.asarray(x)

        n = self.n
        pi = 1
        p = npy.zeros((m, self.r))
        p += self.c[0, npy.newaxis, :]
        # Horner-like accumulation of the Newton form
        for k in range(1, n):
            w = x - self.xi[k-1]
            pi = w*pi
            p = p + npy.multiply.outer(pi, self.c[k])
        if not self.vector_valued:
            if scalar:
                return p[0, 0]
            else:
                return p[:, 0]
        else:
            if scalar:
                return p[0]
            else:
                return p

    def derivatives(self, x, der=None):
        """Evaluate many derivatives of the polynomial at the point x

        Produce an array of all derivative values at the point x.

        Parameters
        ----------
        x : scalar or array-like of length N
            Point or points at which to evaluate the derivatives
        der : None or integer
            How many derivatives to extract; None for all potentially
            nonzero derivatives (that is a number equal to the number
            of points). This number includes the function value as 0th
            derivative.
        Returns
        -------
        d : array
            If the interpolator's values are R-dimensional then the
            returned array will be der by N by R. If x is a scalar,
            the middle dimension will be dropped; if R is 1 then the
            last dimension will be dropped.

        Example
        -------
        >>> KroghInterpolator([0,0,0],[1,2,3]).derivatives(0)
        array([1.0,2.0,3.0])
        >>> KroghInterpolator([0,0,0],[1,2,3]).derivatives([0,0])
        array([[1.0,1.0],
               [2.0,2.0],
               [3.0,3.0]])
        """
        # consistent with __init__: use math.factorial (the original used a
        # bare 'factorial' here but 'spy.factorial' in __init__)
        from math import factorial
        if npy.isscalar(x):
            scalar = True
            m = 1
        else:
            scalar = False
            m = len(x)
        x = npy.asarray(x)

        n = self.n
        r = self.r

        if der is None:
            der = self.n
        pi = npy.zeros((n, m))
        w = npy.zeros((n, m))
        pi[0] = 1
        p = npy.zeros((m, self.r))
        p += self.c[0, npy.newaxis, :]

        for k in range(1, n):
            w[k-1] = x - self.xi[k-1]
            pi[k] = w[k-1]*pi[k-1]
            p += npy.multiply.outer(pi[k], self.c[k])

        cn = npy.zeros((max(der, n+1), m, r))
        cn[:n+1, ...] += self.c[:n+1, npy.newaxis, :]
        cn[0] = p
        for k in range(1, n):
            for i in range(1, n-k+1):
                pi[i] = w[k+i-1]*pi[i-1] + pi[i]
                cn[k] = cn[k] + pi[i, :, npy.newaxis]*cn[k+i]
            cn[k] *= factorial(k)

        cn[n, ...] = 0
        if not self.vector_valued:
            if scalar:
                return cn[:der, 0, 0]
            else:
                return cn[:der, :, 0]
        else:
            if scalar:
                return cn[:der, 0]
            else:
                return cn[:der]

    def derivative(self, x, der):
        """Evaluate one derivative of the polynomial at the point x

        Parameters
        ----------
        x : scalar or array-like of length N
            Point or points at which to evaluate the derivatives
        der : None or integer
            Which derivative to extract. This number includes the
            function value as 0th derivative.
        Returns
        -------
        d : array
            If the interpolator's values are R-dimensional then the
            returned array will be N by R. If x is a scalar,
            the middle dimension will be dropped; if R is 1 then the
            last dimension will be dropped.

        Notes
        -----
        This is computed by evaluating all derivatives up to the desired
        one and then discarding the rest.
        """
        return self.derivatives(x, der=der+1)[der]
2501 2502
class BarycentricInterpolator(object):
    """The interpolating polynomial for a set of points

    Constructs a polynomial that passes through a given set of points.
    Allows evaluation of the polynomial, efficient changing of the y
    values to be interpolated, and updating by adding more x values.
    For reasons of numerical stability, this function does not compute
    the coefficients of the polynomial.

    This class uses a "barycentric interpolation" method that treats
    the problem as a special case of rational function interpolation.
    This algorithm is quite stable, numerically, but even in a world of
    exact computation, unless the x coordinates are chosen very
    carefully - Chebyshev zeros (e.g. cos(i*pi/n)) are a good choice -
    polynomial interpolation itself is a very ill-conditioned process
    due to the Runge phenomenon.

    Based on Berrut and Trefethen 2004, "Barycentric Lagrange Interpolation".
    """
    def __init__(self, xi, yi=None):
        """Construct an object capable of interpolating functions sampled at xi

        The values yi need to be provided before the function is evaluated,
        but none of the preprocessing depends on them, so rapid updates
        are possible.

        Parameters
        ----------
        xi : array-like of length N
            The x coordinates of the points the polynomial should pass through
        yi : array-like N by R or None
            The y coordinates of the points the polynomial should pass through;
            if R>1 the polynomial is vector-valued. If None the y values
            will be supplied later.
        """
        self.n = len(xi)
        self.xi = npy.asarray(xi)
        if yi is not None and len(yi) != len(self.xi):
            raise ValueError("yi dimensions do not match xi dimensions")
        self.set_yi(yi)
        # Incremental computation of the barycentric weights wi
        # (products of abscissa differences, inverted at the end).
        self.wi = npy.zeros(self.n)
        self.wi[0] = 1
        for j in range(1, self.n):
            self.wi[:j] *= (self.xi[j]-self.xi[:j])
            self.wi[j] = npy.multiply.reduce(self.xi[:j]-self.xi[j])
        self.wi **= -1

    def set_yi(self, yi):
        """Update the y values to be interpolated

        The barycentric interpolation algorithm requires the calculation
        of weights, but these depend only on the xi. The yi can be changed
        at any time.

        Parameters
        ----------
        yi : array-like N by R
            The y coordinates of the points the polynomial should pass through;
            if R>1 the polynomial is vector-valued. If None the y values
            will be supplied later.
        """
        if yi is None:
            self.yi = None
            return
        yi = npy.asarray(yi)
        if len(yi.shape) == 1:
            self.vector_valued = False
            yi = yi[:, npy.newaxis]
        elif len(yi.shape) > 2:
            raise ValueError("y coordinates must be either scalars or vectors")
        else:
            self.vector_valued = True

        n, r = yi.shape
        if n != len(self.xi):
            raise ValueError("yi dimensions do not match xi dimensions")
        self.yi = yi
        self.r = r

    def add_xi(self, xi, yi=None):
        """Add more x values to the set to be interpolated

        The barycentric interpolation algorithm allows easy updating by
        adding more points for the polynomial to pass through.

        Parameters
        ----------
        xi : array-like of length N1
            The x coordinates of the points the polynomial should pass through
        yi : array-like N1 by R or None
            The y coordinates of the points the polynomial should pass through;
            if R>1 the polynomial is vector-valued. If None the y values
            will be supplied later. The yi should be specified if and only if
            the interpolator has y values specified.
        """
        if yi is not None:
            if self.yi is None:
                raise ValueError("No previous yi value to update!")
            yi = npy.asarray(yi)
            if len(yi.shape) == 1:
                if self.vector_valued:
                    raise ValueError("Cannot extend dimension %d y vectors with scalars" % self.r)
                yi = yi[:, npy.newaxis]
            elif len(yi.shape) > 2:
                raise ValueError("y coordinates must be either scalars or vectors")
            else:
                n, r = yi.shape
                if r != self.r:
                    raise ValueError("Cannot extend dimension %d y vectors with dimension %d y vectors" % (self.r, r))

            self.yi = npy.vstack((self.yi, yi))
        else:
            if self.yi is not None:
                raise ValueError("No update to yi provided!")
        old_n = self.n
        self.xi = npy.concatenate((self.xi, xi))
        self.n = len(self.xi)
        # Un-invert the existing weights, extend the products with the new
        # abscissae, then invert again.
        self.wi **= -1
        old_wi = self.wi
        self.wi = npy.zeros(self.n)
        self.wi[:old_n] = old_wi
        for j in range(old_n, self.n):
            self.wi[:j] *= (self.xi[j]-self.xi[:j])
            self.wi[j] = npy.multiply.reduce(self.xi[:j]-self.xi[j])
        self.wi **= -1

    def __call__(self, x):
        """Evaluate the interpolating polynomial at the points x

        Parameters
        ----------
        x : scalar or array-like of length M

        Returns
        -------
        y : scalar or array-like of length R or length M or M by R
            The shape of y depends on the shape of x and whether the
            interpolator is vector-valued or scalar-valued.

        Notes
        -----
        Currently the code computes an outer product between x and the
        weights, that is, it constructs an intermediate array of size
        N by M, where N is the degree of the polynomial.
        """
        scalar = npy.isscalar(x)
        x = npy.atleast_1d(x)
        c = npy.subtract.outer(x, self.xi)
        # Exact hits on an abscissa would divide by zero; mask them and
        # patch the results afterwards.
        z = c == 0
        c[z] = 1
        c = self.wi/c
        p = npy.dot(c, self.yi)/npy.sum(c, axis=-1)[:, npy.newaxis]
        i, j = npy.nonzero(z)
        p[i] = self.yi[j]
        if not self.vector_valued:
            if scalar:
                return p[0, 0]
            else:
                return p[:, 0]
        else:
            if scalar:
                return p[0]
            else:
                return p
2668 2669 # RHC - made a sub-class of interpclass
class PiecewisePolynomial(interpclass):
    """Piecewise polynomial curve specified by points and derivatives.

    This class represents a curve that is a piecewise polynomial. It
    passes through a list of points and has specified derivatives at
    each point. The degree of the polynomial may vary from segment to
    segment, as may the number of derivatives available. The degree
    should not exceed about thirty.

    Appending points to the end of the curve is efficient.
    """
    def __init__(self, xi, yi, orders=None, direction=None):
        """Construct a piecewise polynomial

        Parameters
        ----------
        xi : array-like of length N
            a sorted list of x-coordinates
        yi : list of lists of length N
            yi[i] is the list of derivatives known at xi[i]
        orders : list of integers, or integer
            a list of polynomial orders, or a single universal order
        direction : {None, 1, -1}
            indicates whether the xi are increasing or decreasing
            +1 indicates increasing
            -1 indicates decreasing
            None indicates that it should be deduced from the first two xi

        Notes
        -----
        If orders is None, or orders[i] is None, then the degree of the
        polynomial segment is exactly the degree required to match all i
        available derivatives at both endpoints. If orders[i] is not None,
        then some derivatives will be ignored. The code will try to use an
        equal number of derivatives from each end; if the total number of
        derivatives needed is odd, it will prefer the rightmost endpoint. If
        not enough derivatives are available, an exception is raised.
        """
        # RHC added datapoints for use by PyDSTool.
        # Don't store any derivative information in datapoints: keep only
        # the 0th-derivative (function value) at each xi.  Indexing
        # row-by-row rather than with yi[:,0] also accepts the documented
        # "list of lists" form of yi, not just 2D arrays.
        self.datapoints = (array(xi, float),
                           array([row[0] for row in yi], float))
        self.type = float  # RHC -- for access from PyDSTool
        yi0 = npy.asarray(yi[0])
        if len(yi0.shape) == 2:
            self.vector_valued = True
            self.r = yi0.shape[1]
        elif len(yi0.shape) == 1:
            self.vector_valued = False
            self.r = 1
        else:
            raise ValueError("Each derivative must be a vector, not a higher-rank array")

        self.xi = [xi[0]]
        self.yi = [yi0]
        self.n = 1

        self.direction = direction
        self.orders = []
        self.polynomials = []
        self.extend(xi[1:], yi[1:], orders)

    def _make_polynomial(self, x1, y1, x2, y2, order, direction):
        """Construct the interpolating polynomial object

        Deduces the number of derivatives to match at each end
        from order and the number of derivatives available. If
        possible it uses the same number of derivatives from
        each end; if the number is odd it tries to take the
        extra one from y2. In any case if not enough derivatives
        are available at one end or another it draws enough to
        make up the total from the other end.
        """
        n = order+1
        n1 = min(n//2, len(y1))
        n2 = min(n-n1, len(y2))
        n1 = min(n-n2, len(y1))
        if n1+n2 != n:
            raise ValueError("Point %g has %d derivatives, point %g has %d derivatives, but order %d requested" % (x1, len(y1), x2, len(y2), order))
        assert n1 <= len(y1)
        assert n2 <= len(y2)

        xi = npy.zeros(n)
        if self.vector_valued:
            yi = npy.zeros((n, self.r))
        else:
            yi = npy.zeros((n,))

        # Confluent abscissae encode the matched derivatives at each end.
        xi[:n1] = x1
        yi[:n1] = y1[:n1]
        xi[n1:] = x2
        yi[n1:] = y2[:n2]

        return KroghInterpolator(xi, yi)

    def append(self, xi, yi, order=None):
        """Append a single point with derivatives to the PiecewisePolynomial

        Parameters
        ----------
        xi : float
        yi : array-like
            yi is the list of derivatives known at xi
        order : integer or None
            a polynomial order, or instructions to use the highest
            possible order
        """
        yi = npy.asarray(yi)
        if self.vector_valued:
            if (len(yi.shape) != 2 or yi.shape[1] != self.r):
                raise ValueError("Each derivative must be a vector of length %d" % self.r)
        else:
            if len(yi.shape) != 1:
                raise ValueError("Each derivative must be a scalar")

        if self.direction is None:
            self.direction = npy.sign(xi-self.xi[-1])
        elif (xi-self.xi[-1])*self.direction < 0:
            raise ValueError("x coordinates must be in the %d direction: %s" % (self.direction, self.xi))

        self.xi.append(xi)
        self.yi.append(yi)

        if order is None:
            # highest order consistent with all available derivatives
            n1 = len(self.yi[-2])
            n2 = len(self.yi[-1])
            n = n1+n2
            order = n-1

        self.orders.append(order)
        self.polynomials.append(self._make_polynomial(
            self.xi[-2], self.yi[-2],
            self.xi[-1], self.yi[-1],
            order, self.direction))
        self.n += 1

    def extend(self, xi, yi, orders=None):
        """Extend the PiecewisePolynomial by a list of points

        Parameters
        ----------
        xi : array-like of length N1
            a sorted list of x-coordinates
        yi : list of lists of length N1
            yi[i] is the list of derivatives known at xi[i]
        orders : list of integers, or integer
            a list of polynomial orders, or a single universal order
        """
        for i in range(len(xi)):
            if orders is None or npy.isscalar(orders):
                self.append(xi[i], yi[i], orders)
            else:
                self.append(xi[i], yi[i], orders[i])

    def __call__(self, x):
        """Evaluate the piecewise polynomial

        Parameters
        ----------
        x : scalar or array-like of length N

        Returns
        -------
        y : scalar or array-like of length R or length N or N by R
        """
        if npy.isscalar(x):
            pos = npy.clip(npy.searchsorted(self.xi, x) - 1, 0, self.n-2)
            y = self.polynomials[pos](x)
        else:
            x = npy.asarray(x)
            m = len(x)
            pos = npy.clip(npy.searchsorted(self.xi, x) - 1, 0, self.n-2)
            if self.vector_valued:
                y = npy.zeros((m, self.r))
            else:
                y = npy.zeros(m)
            # evaluate each segment's polynomial on the points it owns
            for i in range(self.n-1):
                c = pos == i
                y[c] = self.polynomials[i](x[c])
        return y

    def derivative(self, x, der):
        """Evaluate a derivative of the piecewise polynomial

        Parameters
        ----------
        x : scalar or array-like of length N
        der : integer
            which single derivative to extract

        Returns
        -------
        y : scalar or array-like of length R or length N or N by R

        Notes
        -----
        This currently computes all derivatives of the curve segment
        containing each x but returns only one. This is because the
        number of nonzero derivatives that a segment can have depends
        on the degree of the segment, which may vary.
        """
        return self.derivatives(x, der=der+1)[der]

    def derivatives(self, x, der):
        """Evaluate many derivatives of the piecewise polynomial

        Parameters
        ----------
        x : scalar or array-like of length N
        der : integer
            how many derivatives (including the function value as
            0th derivative) to extract

        Returns
        -------
        y : array-like of shape der by R or der by N or der by N by R
        """
        if npy.isscalar(x):
            pos = npy.clip(npy.searchsorted(self.xi, x) - 1, 0, self.n-2)
            y = self.polynomials[pos].derivatives(x, der=der)
        else:
            x = npy.asarray(x)
            m = len(x)
            pos = npy.clip(npy.searchsorted(self.xi, x) - 1, 0, self.n-2)
            if self.vector_valued:
                y = npy.zeros((der, m, self.r))
            else:
                y = npy.zeros((der, m))
            for i in range(self.n-1):
                c = pos == i
                y[:, c] = self.polynomials[i].derivatives(x[c], der=der)
        return y
    # FIXME: provide multiderivative finder

    # RHC added
    def __getstate__(self):
        """Return a picklable snapshot of the instance dict, replacing the
        unpicklable Cfunc in self.type by its registered name."""
        d = copy(self.__dict__)
        d['type'] = _num_type2name[self.type]
        return d

    # RHC added
    def __setstate__(self, state):
        """Restore pickled state, mapping the stored type name back to its
        Cfunc object."""
        self.__dict__.update(state)
        self.type = _num_name2type[self.type]
2925 2926 # -------------------------------------------------------------------- 2927
def simple_bisection(tlo, thi, f, tol, imax=100):
    """Find a root of f in the bracket [tlo, thi] by bisection.

    tlo, thi -- bracket endpoints; f is assumed to change sign over
        [tlo, thi] (this is not checked).
    f -- scalar function of one variable.
    tol -- stop successfully once the half-interval length is below tol.
    imax -- maximum number of iterations (default 100).

    Returns the approximate root, or None if neither the tolerance was
    reached nor an exact zero found within imax iterations.

    Note: the original version also evaluated f(thi) into an unused
    variable, and ended with 'if i == imax: sol = p', which could never
    change the result (on normal exhaustion i == imax+1, and after a
    break sol already equals p); both have been removed.
    """
    sol = None
    flo = f(tlo)
    i = 1
    while i <= imax:
        d = (thi - tlo)/2.
        p = tlo + d
        if d < tol:
            # interval small enough: accept midpoint
            sol = p
            break
        fp = f(p)
        if fp == 0:
            # landed exactly on a root
            sol = p
            break
        i += 1
        if fp*flo > 0:
            # root lies in the upper half-interval
            tlo = p
            flo = fp
        else:
            thi = p
    return sol
2952 2953 # Function fitting tools 2954
class fit_function(object):
    """Abstract super-class for fitting explicit functions to 1D arrays of data
    using least squares.

    xs -- independent variable data
    ys -- dependent variable data
    pars_ic -- initial values defining the function

    Optional algorithmic parameters to minpack.leastsq can be passed in the
    algpars argument: e.g.,
      ftol -- Relative error desired in the sum of squares (default 1e-8).
      xtol -- Relative error desired in the approximate solution (default 1e-6).
      gtol -- Orthogonality desired between the function vector
              and the columns of the Jacobian (default 1e-8).

    Other parameters may be used for concrete sub-classes. Pass these as a dict
    or args object in the opts argument.

    Returns an args object with attributes:

    ys_fit -- the fitted y values corresponding to the given x data,
    pars_fit -- the function parameters at the fit
    info -- diagnostic feedback from the leastsq algorithm
    results -- dictionary of other function specific information (such as peak
        position)
    """

    def __init__(self, pars_ic=None, algpars=None, opts=None,
                 verbose=False):
        """Store algorithm parameters, initial values and options.

        pars_ic -- initial parameter values (may also be given to fit()).
        algpars -- overrides for the leastsq tolerances/maxfev defaults.
        opts -- args object or dict of sub-class specific options; a
            'weight' entry scales the residuals.
        verbose -- if True, print parameters and residual norms during
            the fit.
        """
        # defaults (note: the class docstring defaults match these values)
        self.algpars = args(ftol=1e-8, xtol=1e-6, gtol=1e-8, maxfev=100)
        if algpars is not None:
            self.algpars.update(algpars)
        self.verbose = verbose
        self.pars_ic = pars_ic
        # opts may be an attribute-style args object or, as documented,
        # a plain dict
        if hasattr(opts, 'weight'):
            self.weight = opts.weight
        elif isinstance(opts, dict) and 'weight' in opts:
            self.weight = opts['weight']
        else:
            self.weight = 1

    def fn(self, x, *pars):
        """Model function; override in a concrete sub-class."""
        raise NotImplementedError("Override in a concrete sub-class")

    def _do_fit(self, constraint, xs, ys, pars_ic):
        """Run minpack.leastsq on the weighted residual of self.fn against
        the data (xs, ys), optionally prepending soft-constraint residuals.

        Returns the full leastsq output tuple.
        """
        xs = asarray(xs)
        ys = asarray(ys)
        weight = self.weight

        # Build the residual function once, outside the optimizer loop.
        if constraint is None:
            if self.verbose:
                def res_fn(p):
                    print("\n " + str(p))
                    r = self.fn(xs, *p) - ys
                    print("Residual = %f" % norm(r*weight))
                    return r*weight
            else:
                def res_fn(p):
                    r = self.fn(xs, *p) - ys
                    return r*weight
        else:
            if self.verbose:
                def res_fn(p):
                    print("\n " + str(p))
                    r = npy.concatenate((constraint(*p),
                                         (self.fn(xs, *p) - ys)*weight))
                    print("Residual = %f" % norm(r))
                    return r
            else:
                def res_fn(p):
                    return npy.concatenate((constraint(*p),
                                            (self.fn(xs, *p) - ys)*weight))

        try:
            res = minpack.leastsq(res_fn, pars_ic,
                                  full_output=True,
                                  ftol=self.algpars.ftol,
                                  xtol=self.algpars.xtol,
                                  gtol=self.algpars.gtol,
                                  maxfev=self.algpars.maxfev)
        except Exception:
            # purely diagnostic: report where the optimizer failed, then
            # re-raise (was a bare except, narrowed to Exception)
            print("Error at parameters " + str(pars_ic))
            raise
        if self.verbose:
            print("Result: " + str(res))
        return res

    def fit(self, xs, ys, pars_ic=None, opts=None):
        """Perform the fit; override in a concrete sub-class."""
        raise NotImplementedError("Override in a concrete sub-class")
3041 3042
class fit_quadratic(fit_function):
    """Least-squares fit of y = a*x**2 + b*x + c to 1D (x, y) data.

    When no initial parameters (a, b, c) are supplied (here or at
    construction), (1, 1, 0) is used.

    A soft turning-point constraint may be supplied via opts.peak_constraint
    as a tuple (x_index, y_value, weight_x, weight_y) giving the approximate
    position of a turning point in the data.

    result.peak is a (xpeak, ypeak) pair.
    result.f is the fitted function (accepts x values).
    """

    def fn(self, x, a, b, c):
        return a*x**2+b*x+c

    def fit(self, xs, ys, pars_ic=None, opts=None):
        peak_constraint = getattr(opts, 'peak_constraint', None)

        if pars_ic is None:
            if self.pars_ic is None:
                pars_ic = array([1.,1.,0.])
            else:
                pars_ic = self.pars_ic

        constraint = None
        if peak_constraint is not None:
            x_index, y_value, weight_x, weight_y = peak_constraint
            def constraint(a, b, c):
                # soft residuals penalizing deviation of the fitted peak
                # height and position from the given values
                return array([weight_y*(self.fn(xs[x_index],a,b,c)-y_value),
                              weight_x*(xs[x_index]+b/(2*a))])

        res = self._do_fit(constraint, xs, ys, pars_ic)
        a, b, c = res[0]
        def f(x):
            return a*x**2+b*x+c
        xpeak = -b/(2*a)
        ypeak = f(xpeak)
        return args(ys_fit=f(xs), pars_fit=(a, b, c), info=res,
                    results=args(peak=(xpeak, ypeak), f=f))
class fit_quadratic_at_vertex(fit_function):
    """Least-squares fit of y = a*(x+h)**2 + k with the vertex pinned at
    (h, k), leaving only the curvature a free. (h, k) is supplied through
    the peak_constraint entry of the 'opts' initialization argument.

    If no initial value for a is given, 1 is used.

    result.peak echoes (h, k).
    result.f is the fitted function (accepts x values).
    """

    def fn(self, x, a):
        return a*(x+self.h)**2+self.k

    def fit(self, xs, ys, pars_ic=None, opts=None):
        # fixed vertex coordinates, stored for use by fn
        self.h, self.k = opts.peak_constraint
        if pars_ic is None:
            if self.pars_ic is None:
                pars_ic = (1,)
            else:
                pars_ic = (self.pars_ic,)

        res = self._do_fit(None, xs, ys, pars_ic)
        a = res[0]
        def f(x):
            return a*(x+self.h)**2+self.k
        return args(ys_fit=f(xs), pars_fit=a, info=res,
                    results=args(peak=(self.h, self.k), f=f))
3122
class fit_cubic(fit_function):
    """Least-squares fit of y = a*x**3 + b*x**2 + c*x + d to 1D (x, y) data.

    When no initial parameters (a, b, c, d) are supplied (here or at
    construction), (1, 1, 1, 0) is used.

    result.f is the fitted function (accepts x values).
    """

    def fn(self, x, a, b, c, d):
        return a*x**3+b*x*x+c*x+d

    def fit(self, xs, ys, pars_ic=None, opts=None):
        if pars_ic is None:
            pars_ic = self.pars_ic
            if pars_ic is None:
                pars_ic = array([1.,1.,1.,0.])

        res = self._do_fit(None, xs, ys, pars_ic)
        a, b, c, d = res[0]
        def f(x):
            return a*x**3+b*x*x+c*x+d
        return args(ys_fit=f(xs), pars_fit=(a,b,c,d), info=res,
                    results=args(f=f))
3149 3150 3151
class fit_exponential(fit_function):
    """Least-squares fit of y = a*exp(b*x) to 1D (x, y) data.

    When no initial parameters (a, b) are supplied (here or at
    construction), (1, -1) is used.

    result.f is the fitted function (accepts x values).
    """

    def fn(self, x, a, b):
        return a*exp(b*x)

    def fit(self, xs, ys, pars_ic=None, opts=None):
        if pars_ic is None:
            pars_ic = self.pars_ic
            if pars_ic is None:
                pars_ic = array([1.,-1.])

        res = self._do_fit(None, xs, ys, pars_ic)
        a, b = res[0]
        def f(x):
            return a*exp(b*x)
        return args(ys_fit=f(xs), pars_fit=(a,b), info=res,
                    results=args(f=f))
3178 3179
class fit_diff_of_exp(fit_function):
    """Fit a 'difference of two exponentials' function
    y = k*a*b*(exp(-a*x)-exp(-b*x))/(b-a) to the (x,y) array data.
    If initial parameter values = (k,a,b) are not given, the values
    (1,1,1) will be used (where the function degenerates to
    y = k*a*a*x*exp(-a*x)).

    Optional use_xoff feature adds offset to x and subtracts the resulting
    value at x=0 as a baseline, so that
    y = k*a*a*((x+xoff)*exp(-a*(x+xoff)) - xoff*exp(-a*xoff))
    etc., in case fitting data that starts at larger values than its tail.
    Then initial parameter values will be (1,1,1,0) unless given otherwise.

    If peak_constraint option is used, it is a tuple of values (x_index,
    y_value, weight_x, weight_y) for the approximate position of a turning
    point in the data; it is applied as a soft constraint in the fit.

    result.peak is a (xpeak, ypeak) pair.
    result.f is the fitted function (accepts x values).
    """

    def fn(self, x, k, a, b, xoff=0):
        """Model function; baseline-shifted so fn(0, ...) = 0."""
        if a==b:
            # classic "alpha" function (degenerate limit b -> a)
            return k*a*a*((x+xoff)*exp(-a*(x+xoff)) - xoff*exp(-a*xoff))
        else:
            # BUGFIX: the two baseline terms previously appeared as
            # +exp(-a*xoff)-exp(-b*xoff), which *added* the x=0 baseline
            # instead of subtracting it, making this branch discontinuous
            # with the a==b branch in the limit b -> a. Corrected signs
            # recover the a==b expression exactly in that limit.
            return k*a*b*(exp(-a*(x+xoff))-exp(-a*xoff)-exp(-b*(x+xoff))+exp(-b*xoff))/(b-a)

    def fit(self, xs, ys, pars_ic=None, opts=None):
        try:
            peak_constraint = opts.peak_constraint
        except AttributeError:
            peak_constraint = None
        try:
            use_xoff = opts.use_xoff
        except AttributeError:
            use_xoff = False

        def peak_pos(k, a, b, xoff=0):
            # position of the turning point; the constant baseline shift
            # does not affect it
            if a==b:
                return 1./a - xoff
            else:
                return ((b-a)*xoff+log(a/b))/(a-b)

        if pars_ic is None:
            if self.pars_ic is None:
                if use_xoff:
                    pars_ic = array([1.,1.,1.,0.])
                else:
                    pars_ic = array([1.,1.,1.])
            else:
                pars_ic = self.pars_ic
                if (len(self.pars_ic) == 4 and not use_xoff) or \
                   (len(self.pars_ic) == 3 and use_xoff):
                    raise ValueError("Inconsistent use_xoff setting with pars_ic")

        if peak_constraint is None:
            constraint = None
        else:
            x_index, y_value, weight_x, weight_y = peak_constraint
            def constraint(k,a,b,xoff=0):
                # soft residuals penalizing deviation of the fitted peak
                # height and position from the given values
                return array([weight_y*(self.fn(xs[x_index],k,a,b,xoff)-y_value),
                              weight_x*(xs[x_index]-peak_pos(k,a,b,xoff))])
        res = self._do_fit(constraint, xs, ys, pars_ic)
        sol = res[0]
        if use_xoff:
            k,a,b,xoff = sol
        else:
            k,a,b = sol
            xoff = 0
        # build the fitted function, specializing for the xoff == 0 and
        # degenerate a == b cases
        if xoff == 0:
            if a == b:
                # exceptional case: alpha function
                def f(x):
                    return k*a*a*x*exp(-a*x)
            else:
                def f(x):
                    return k*a*b*(exp(-a*x)-exp(-b*x))/(b-a)
        else:
            if a == b:
                # exceptional case: alpha function, baseline-shifted
                def f(x):
                    return k*a*a*((x+xoff)*exp(-a*(x+xoff)) - xoff*exp(-a*xoff))
            else:
                # BUGFIX: baseline term signs corrected to match fn (above)
                def f(x):
                    return k*a*b*(exp(-a*(x+xoff))-exp(-a*xoff)-exp(-b*(x+xoff))+exp(-b*xoff))/(b-a)
        ys_fit = f(xs)
        xpeak = peak_pos(k,a,b,xoff)
        ypeak = f(xpeak)
        if use_xoff:
            pars_fit = (k, a, b, xoff)
        else:
            pars_fit = (k, a, b)
        return args(ys_fit=ys_fit, pars_fit=pars_fit, info=res,
                    results=args(peak=(xpeak, ypeak), f=f))
class fit_linear(fit_function):
    """Least-squares fit of y = a*x + b to 1D (x, y) data.

    When no initial parameters (a, b) are supplied (here or at
    construction), (1, 0) is used.

    result.f is the fitted function (accepts x values).
    """

    def fn(self, x, a, b):
        return a*x+b

    def fit(self, xs, ys, pars_ic=None, opts=None):
        if pars_ic is None:
            pars_ic = self.pars_ic
            if pars_ic is None:
                pars_ic = array([1.,0.])

        res = self._do_fit(None, xs, ys, pars_ic)
        a, b = res[0]
        def f(x):
            return a*x+b
        return args(ys_fit=f(xs), pars_fit=(a,b), info=res,
                    results=args(f=f))
3302
def make_poly_interpolated_curve(pts, coord, model):
    """Build a degree-2 piecewise-polynomial interpolant of one coordinate
    of a 1D curve, using the model's vector field for the curve's first
    derivative at each point.

    Only for a 1D curve from a Model object (that has an associated vector
    field defining the derivative).
    """
    coord_ix = pts.coordnames.index(coord)
    ts = pts.indepvararray
    pars = model.query('pars')
    # derivative of the chosen coordinate at each sample, from the RHS
    derivs = []
    for tix, tval in enumerate(ts):
        derivs.append(model.Rhs(tval, pts[tix], pars, asarray=True)[coord_ix])
    vals_and_derivs = array([pts[coord], array(derivs)]).T
    return PiecewisePolynomial(ts, vals_and_derivs, 2)
3314
def smooth_pts(t, x, q=None):
    """Use a local quadratic fit on a set of nearby 1D points and obtain
    a function that represents that fit in that neighbourhood. Returns a
    structure (args object) with attributes ys_fit, pars_fit, info, and
    results. The function can be referenced as results.f

    It is assumed that the data is small enough that it is purely concave
    up or down, and that it contains at least five points.

    If this function is used repeatedly, pass a fit_quadratic instance
    as the argument q.
    """
    if q is None:
        q = fit_quadratic(verbose=False)
    ixlo = 0
    ixhi = len(t)-1
    assert ixhi >= 4, "Provide at least five points"
    # concavity assumed to be simple: determined by whether the midpoint
    # of x is above or below the chord between the endpoints
    midpoint_ix = int(ixhi/2.)
    midpoint_chord = x[0]+(t[midpoint_ix]-t[0])*(x[-1]-x[0])/(t[-1]-t[0])
    midpoint_x = x[midpoint_ix]
    # a_sign is the sign of the quadratic coefficient a: -1 if concave down
    a_sign = sign(midpoint_chord - midpoint_x)
    ixmax = argmax(x)
    ixmin = argmin(x)
    # to estimate |a| need to know where best to put centre for
    # central second difference formula:
    # if extremum not at endpoints then use one endpoint
    # else use central point
    if (ixmax in (ixhi, ixlo) and a_sign == -1) or \
       (ixmin in (ixhi, ixlo) and a_sign == 1):
        # use central point, guaranteed to be at least 2 indices away from
        # ends
        ix_cent = midpoint_ix
    else:
        # use a point two indices in from the low end
        ix_cent = ixlo+2
    # use mean of right and left t steps as h (should be safe for
    # smooth enough data)
    h = 0.25*(t[ix_cent+2]-t[ix_cent-2])
    # five-point central second difference
    second_diff = (-x[ix_cent-2]+16*x[ix_cent-1]-30*x[ix_cent]+\
                   16*x[ix_cent+1]-x[ix_cent+2])/(12*h**2)
    assert sign(second_diff) == a_sign, "Data insufficiently smooth"
    # a_est based on second deriv of quadratic formula = 2a
    a_est = second_diff/2.
    # BUGFIX: when concave down (a_sign == -1) the vertex is a MAXIMUM,
    # so the extreme point must come from argmax (the original had the
    # min/max choice inverted)
    if a_sign == -1:
        extreme_x = x[ixmax]
        extreme_t = t[ixmax]
    else:
        extreme_x = x[ixmin]
        extreme_t = t[ixmin]
    # using vertex form of quadratic, x = a*( t-extreme_t )^2 + extreme_x;
    # expanding to x = a*t^2 + b*t + c used by the quadratic fit class gives
    # b = -2*a*extreme_t and c = a*extreme_t^2 + extreme_x
    # BUGFIX: the original omitted the factor a_est in b_est
    b_est = -2*a_est*extreme_t
    c_est = a_est*extreme_t**2 + extreme_x
    return q.fit(t, x, pars_ic=(a_est,b_est,c_est))
def nearest_2n_indices(x, i, n):
    """Return the limiting indices (lo, hi) of the window of 2n+1 indices
    in x centred as closely as possible on index i, accounting for i being
    within n indices of either end of x.

    Away from the endpoints the result is (i-n, i+n).
    If i is within n of index 0, the result is (0, 2n).
    If i is within n of the last index L, the result is (L-2n, L).
    The window always contains 2n+1 indices (assuming x is long enough).

    Remember to add one to the upper limit if using it in a slice.
    """
    assert len(x) > 2*n, "x is not long enough"
    last = len(x) - 1
    lo = i - n
    hi = i + n
    if lo < 0:
        # window would overhang the low end: clamp to the start
        return (0, 2*n)
    if hi > last:
        # window would overhang the high end: clamp to the finish
        return (last - 2*n, last)
    return (lo, hi)
3411 3412 3413 # -------------------------------------------------------------------- 3414
class DomainType(object):
    """Named tag object for a domain type (see the Continuous and Discrete
    singletons defined below). Instances compare equal iff their names are
    equal; an object without a .name attribute is never equal to one.
    """
    def __init__(self, name):
        # name doubles as the repr/str of the instance
        self.name = name

    def __eq__(self, other):
        try:
            return self.name == other.name
        except AttributeError:
            # BUGFIX: was a bare except; only a missing .name means
            # "not comparable", i.e. not equal
            return False

    def __ne__(self, other):
        try:
            return self.name != other.name
        except AttributeError:
            # BUGFIX: originally returned False here, so comparing with a
            # non-DomainType object made both == and != False; objects of
            # incompatible type are unequal
            return True

    def __repr__(self):
        return self.name

    __str__ = __repr__
# treat these as "constants" as they are empty
# (a module-level 'global Continuous, Discrete' statement previously
# appeared here; 'global' is a no-op at module scope and has been removed)

Continuous = DomainType("Continuous Domain")
Discrete = DomainType("Discrete Domain")


#-----------------------------------------------------------------------------
# The following code, in particular the Verbose class, was written by
# John D. Hunter as part of the front end of MatplotLib.
# (See matplotlib.sourceforge.net/license.html for details.)
#
# Copyright (c) 2002-2004 John D. Hunter; All Rights Reserved
#-----------------------------------------------------------------------------

# This is not yet used in PyDSTool
class Verbose(object):
    """
    A class to handle reporting. Set the fileo attribute to any file
    instance to handle the output. Default is sys.stdout
    """
    # verbosity level names, in increasing order of chattiness
    levels = ('silent', 'error', 'helpful', 'debug', 'debug-annoying')
    # map each level name to its numeric rank, for ordered comparison
    vald = dict( [(level, i) for i,level in enumerate(levels)])

    # parse the verbosity from the command line; flags look like
    # --verbose-error or --verbose-helpful
    _commandLineVerbose = None


    for arg in sys.argv[1:]:
        if not arg.startswith('--verbose-'): continue
        # strip the 10-character '--verbose-' prefix to get the level name
        _commandLineVerbose = arg[10:]
    def __init__(self, level):
        """Validate and store the verbosity level; output defaults to
        sys.stdout (reports) and sys.stderr (errors)."""
        self.setLevel(level)
        self.fileo = sys.stdout  # destination stream for normal reports
        self.erro = sys.stderr   # destination stream for error reports
3473
3474 - def setLevel(self, level):
3475 'set the verbosity to one of the Verbose.levels strings' 3476 3477 if self._commandLineVerbose is not None: 3478 level = self._commandLineVerbose 3479 if level not in self.levels: 3480 raise ValueError('Illegal verbose string "%s". Legal values are %s'%(level, self.levels)) 3481 self.level = level
3482
3483 - def report(self, s, level='helpful'):
3484 """ 3485 print message s to self.fileo if self.level>=level. Return 3486 value indicates whether a message was issue. 3487 """ 3488 if self.ge(level): 3489 print >>self.fileo, s 3490 return True 3491 return False
3492
3493 - def report_error(self, s):
3494 """ 3495 print message s to self.fileo if self.level>=level. Return 3496 value indicates whether a message was issued 3497 """ 3498 if self.ge('error'): 3499 print >>self.erro, s 3500 return True 3501 return False
3502 3503
    def wrap(self, fmt, func, level='helpful', always=True):
        """
        return a callable function that wraps func and reports it
        output through the verbose handler if current verbosity level
        is higher than level

        if always is True, the report will occur on every function
        call; otherwise only on the first time the function is called
        """
        assert callable(func)
        def wrapper(*args, **kwargs):
            ret = func(*args, **kwargs)

            # report every call when always=True, otherwise only until the
            # first report is actually issued (tracked via wrapper._spoke)
            if (always or not wrapper._spoke):
                spoke = self.report(fmt%ret, level)
                if not wrapper._spoke: wrapper._spoke = spoke
            return ret
        wrapper._spoke = False  # set True once a report has been issued
        wrapper.__doc__ = func.__doc__  # preserve wrapped function's docstring
        return wrapper
3524
3525 - def ge(self, level):
3526 'return true if self.level is >= level' 3527 return self.vald[self.level]>=self.vald[level]
3528