diff --git a/skyfield/keplerlib.py b/skyfield/keplerlib.py index 9ec83b68c..3faa9744d 100644 --- a/skyfield/keplerlib.py +++ b/skyfield/keplerlib.py @@ -5,7 +5,7 @@ from numpy import(abs, amax, amin, arange, arccos, arctan, array, atleast_1d, clip, copy, copyto, cos, cosh, exp, float64, full_like, log, ndarray, newaxis, pi, power, repeat, sin, sinh, squeeze, - sqrt, sum, tan, tanh, zeros_like) + sqrt, sum, tan, tanh, zeros_like, newaxis) from skyfield.constants import AU_KM, DAY_S, DEG2RAD from skyfield.functions import dots, length_of, mxv @@ -536,7 +536,10 @@ def kepler_1d(x, orb_inds): t0 = repeat(t0, position.shape[1]) # shape of 2 dimensional arrays from here on out should be (#orbits, len(t1)) - dt = t1 - t0[:, newaxis] + if t1.shape == t0.shape: + dt = t1[:, newaxis] - t0[:, newaxis] + else: + dt = t1 - t0[:, newaxis] x = dt/bq copyto(x, -bound, where=(x<-bound)) diff --git a/skyfield/vectorlib.py b/skyfield/vectorlib.py index 57ee37a0b..9382a826e 100644 --- a/skyfield/vectorlib.py +++ b/skyfield/vectorlib.py @@ -1,7 +1,7 @@ """Vector functions and their composition.""" from jplephem.names import target_names as _jpl_code_name_dict -from numpy import max +from numpy import max, newaxis from .constants import C_AUDAY from .descriptorlib import reify from .errors import DeprecationError @@ -215,8 +215,12 @@ def _at(self, t): p2, v2, another_gcrs_position, message = vf._at(t) if gcrs_position is None: # TODO: so bootleg; rework whole idea gcrs_position = another_gcrs_position - p += p2 - v += v2 + if not isinstance(p, float) and len(p2.shape) > len(p.shape): + p = p2 + p[:,newaxis] + v = v2 + v[:,newaxis] + else: + p += p2 + v += v2 return p, v, gcrs_position, message def _correct_for_light_travel_time(observer, target): @@ -238,7 +242,10 @@ def _correct_for_light_travel_time(observer, target): cvelocity = observer.velocity.au_per_d tposition, tvelocity, gcrs_position, message = target._at(t) - + if len(cposition.shape) != len(tposition.shape) and len(tposition.shape) == 2 and len(cposition.shape) == 1: + cposition = cposition[:,newaxis] + cvelocity = cvelocity[:,newaxis] + distance = length_of(tposition - cposition) light_time0 = 0.0 for i in range(10):