Skip to content

Commit

Permalink
Merge pull request #108 from moeyensj/v2.0-obs-source
Browse files Browse the repository at this point in the history
V2.0 obs source
  • Loading branch information
moeyensj authored Aug 31, 2023
2 parents 99c110f + cb8e4cc commit ddcda9f
Show file tree
Hide file tree
Showing 3 changed files with 256 additions and 2 deletions.
132 changes: 132 additions & 0 deletions thor/observation_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import abc
from typing import TypeAlias

import numpy as np
import quivr as qv
from adam_core.observations import detections, exposures

from . import orbit


class Observations:
"""Observations represents a collection of exposures and the
detections they contain.
The detections may be a filtered subset of the detections in the
exposures.
"""

detections: detections.PointSourceDetections
exposures: exposures.Exposures
linkage: qv.Linkage[detections.PointSourceDetections, exposures.Exposures]

def __init__(
self,
detections: detections.PointSourceDetections,
exposures: exposures.Exposures,
):
self.detections = detections
self.exposures = exposures
self.linkage = qv.Linkage(
detections,
exposures,
left_keys=detections.exposure_id,
right_keys=exposures.id,
)


class ObservationSource(abc.ABC):
"""An ObservationSource is a source of observations for a given test orbit.
It has one method, gather_observations, which takes a test orbit
and returns a collection of Observations.
"""

@abc.abstractmethod
def gather_observations(self, test_orbit: orbit.TestOrbit) -> Observations:
pass


class FixedRadiusObservationSource(ObservationSource):
"""A FixedRadiusObservationSource is an ObservationSource that
gathers observations within a fixed radius of the test orbit's
ephemeris at each exposure time within a collection of exposures.
"""

def __init__(self, radius: float, all_observations: Observations):
"""
radius: The radius of the cell in degrees
"""
self.radius = radius
self.all_observations = all_observations

def gather_observations(self, test_orbit: orbit.TestOrbit) -> Observations:
# Generate an ephemeris for every observer time/location in the dataset
observers = self.all_observations.exposures.observers()
ephems_linkage = test_orbit.generate_ephemeris(
observers=observers,
)

matching_detections = detections.PointSourceDetections.empty()
matching_exposures = exposures.Exposures.empty()

# Build a mapping of exposure_id to ephemeris ra and dec
for exposure in self.all_observations.exposures:
key = ephems_linkage.key(
code=exposure.observatory_code[0].as_py(),
mjd=exposure.midpoint().mjd()[0].as_py(),
)
ephem = ephems_linkage.select_left(key)
assert len(ephem) == 1, "there should be exactly one ephemeris per exposure"

ephem_ra = ephem.coordinates.lon[0].as_py()
ephem_dec = ephem.coordinates.lat[0].as_py()

exp_dets = self.all_observations.linkage.select_left(exposure.id[0])

nearby_dets = _within_radius(exp_dets, ephem_ra, ephem_dec, self.radius)
if len(nearby_dets) > 0:
matching_exposures = qv.concatenate([matching_exposures, exposure])
matching_detections = qv.concatenate([matching_detections, nearby_dets])

return Observations(matching_detections, matching_exposures)


def _within_radius(
detections: detections.PointSourceDetections,
ra: float,
dec: float,
radius: float,
) -> detections.PointSourceDetections:
"""
Return the detections within a given radius of a given ra and dec
"""
sdlon = np.sin(detections.ra.to_numpy() - ra)
cdlon = np.cos(detections.ra.to_numpy() - ra)
slat1 = np.sin(dec)
slat2 = np.sin(detections.dec.to_numpy())
clat1 = np.cos(dec)
clat2 = np.cos(detections.dec.to_numpy())

num1 = clat2 * sdlon
num2 = clat1 * slat2 - slat1 * clat2 * cdlon
denominator = slat1 * slat2 + clat1 * clat2 * cdlon

distances = np.arctan2(np.hypot(num1, num2), denominator)

mask = distances <= radius
return detections.apply_mask(mask)


class StaticObservationSource(ObservationSource):
"""A StaticObservationSource is an ObservationSource that
returns a fixed collection of observations for any test orbit.
"""

def __init__(self, observations: Observations):
self.observations = observations

def gather_observations(self, test_orbit: orbit.TestOrbit) -> Observations:
return self.observations
3 changes: 1 addition & 2 deletions thor/orbit.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,7 @@ def __init__(
else:
self.orbit_id = uuid.uuid4().hex

if object_id is not None:
self.object_id = object_id
self.object_id = object_id

self._orbit = Orbits.from_kwargs(
orbit_id=[self.orbit_id],
Expand Down
123 changes: 123 additions & 0 deletions thor/tests/test_observation_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import astropy.time
import numpy as np
import pyarrow as pa
import pytest
import quivr as qv
from adam_core import coordinates, observers, propagator
from adam_core.observations import detections, exposures

from .. import observation_source, orbit


@pytest.fixture
def fixed_test_orbit():
# An orbit at 1AU going around at (about) 1 degree per day
coords = coordinates.CartesianCoordinates.from_kwargs(
x=[1],
y=[0],
z=[0],
vx=[0],
vy=[2*np.pi/365.25],
vz=[0],
time=coordinates.Times.from_astropy(astropy.time.Time("2020-01-01T00:00:00")),
origin=coordinates.Origin.from_kwargs(code=["SUN"]),
frame="ecliptic",
)

return orbit.TestOrbit(
coordinates=coords,
orbit_id="test_orbit",
)


@pytest.fixture
def fixed_observers():
times = astropy.time.Time(
[
"2020-01-01T00:00:00",
"2020-01-02T00:00:00",
"2020-01-03T00:00:00",
"2020-01-04T00:00:00",
"2020-01-05T00:00:00",
]
)
return observers.Observers.from_code("I11", times)


@pytest.fixture
def fixed_ephems(fixed_test_orbit, fixed_observers):
prop = propagator.PYOORB()
return prop.generate_ephemeris(fixed_test_orbit.orbit, fixed_observers).left_table


@pytest.fixture
def fixed_exposures(fixed_observers):
return exposures.Exposures.from_kwargs(
id=[str(i) for i in range(len(fixed_observers))],
start_time=fixed_observers.coordinates.time,
duration=[30 for i in range(len(fixed_observers))],
filter=["i" for i in range(len(fixed_observers))],
observatory_code=fixed_observers.code,
)


@pytest.fixture
def fixed_detections(fixed_ephems, fixed_exposures):
# Return PointSourceDetections which form a 100 x 100 grid in
# RA/Dec, evenly spanning 1 square degree, for each exposure
detection_tables = []
for ephem, exposure in zip(fixed_ephems, fixed_exposures):
ra_center = ephem.coordinates.lon[0].as_py()
dec_center = ephem.coordinates.lat[0].as_py()

ras = np.linspace(ra_center - 0.5, ra_center + 0.5, 100)
decs = np.linspace(dec_center - 0.5, dec_center + 0.5, 100)

ra_decs = np.meshgrid(ras, decs)

N = len(ras) * len(decs)
ids = [str(i) for i in range(N)]
exposure_ids = pa.concat_arrays([exposure.id] * N)
magnitudes = [20] * N

detection_tables.append(
detections.PointSourceDetections.from_kwargs(
id=ids,
exposure_id=exposure_ids,
ra=ra_decs[0].flatten(),
dec=ra_decs[1].flatten(),
mag=magnitudes,
)
)
return qv.concatenate(detection_tables)


@pytest.fixture
def fixed_observations(fixed_detections, fixed_exposures):
return observation_source.Observations(fixed_detections, fixed_exposures)


def test_observation_fixtures(fixed_test_orbit, fixed_observations):
assert len(fixed_test_orbit.orbit) == 1
assert len(fixed_observations.exposures) == 5
assert len(fixed_observations.detections) == 100 * 100 * 5


def test_static_observation_source(fixed_test_orbit, fixed_observations):
sos = observation_source.StaticObservationSource(observations=fixed_observations)
have = sos.gather_observations(fixed_test_orbit)

assert have == fixed_observations


def test_fixed_radius_observation_source(fixed_test_orbit, fixed_observations):
fos = observation_source.FixedRadiusObservationSource(
radius=0.5,
all_observations=fixed_observations,
)
have = fos.gather_observations(fixed_test_orbit)
assert len(have.exposures) == 5
assert have.exposures == fixed_observations.exposures
assert len(have.detections) < len(fixed_observations.detections)
assert len(have.detections) > 0.75 * len(fixed_observations.detections)

0 comments on commit ddcda9f

Please sign in to comment.