-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #108 from moeyensj/v2.0-obs-source
V2.0 obs source
- Loading branch information
Showing
3 changed files
with
256 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
import abc | ||
from typing import TypeAlias | ||
|
||
import numpy as np | ||
import quivr as qv | ||
from adam_core.observations import detections, exposures | ||
|
||
from . import orbit | ||
|
||
|
||
class Observations: | ||
"""Observations represents a collection of exposures and the | ||
detections they contain. | ||
The detections may be a filtered subset of the detections in the | ||
exposures. | ||
""" | ||
|
||
detections: detections.PointSourceDetections | ||
exposures: exposures.Exposures | ||
linkage: qv.Linkage[detections.PointSourceDetections, exposures.Exposures] | ||
|
||
def __init__( | ||
self, | ||
detections: detections.PointSourceDetections, | ||
exposures: exposures.Exposures, | ||
): | ||
self.detections = detections | ||
self.exposures = exposures | ||
self.linkage = qv.Linkage( | ||
detections, | ||
exposures, | ||
left_keys=detections.exposure_id, | ||
right_keys=exposures.id, | ||
) | ||
|
||
|
||
class ObservationSource(abc.ABC): | ||
"""An ObservationSource is a source of observations for a given test orbit. | ||
It has one method, gather_observations, which takes a test orbit | ||
and returns a collection of Observations. | ||
""" | ||
|
||
@abc.abstractmethod | ||
def gather_observations(self, test_orbit: orbit.TestOrbit) -> Observations: | ||
pass | ||
|
||
|
||
class FixedRadiusObservationSource(ObservationSource): | ||
"""A FixedRadiusObservationSource is an ObservationSource that | ||
gathers observations within a fixed radius of the test orbit's | ||
ephemeris at each exposure time within a collection of exposures. | ||
""" | ||
|
||
def __init__(self, radius: float, all_observations: Observations): | ||
""" | ||
radius: The radius of the cell in degrees | ||
""" | ||
self.radius = radius | ||
self.all_observations = all_observations | ||
|
||
def gather_observations(self, test_orbit: orbit.TestOrbit) -> Observations: | ||
# Generate an ephemeris for every observer time/location in the dataset | ||
observers = self.all_observations.exposures.observers() | ||
ephems_linkage = test_orbit.generate_ephemeris( | ||
observers=observers, | ||
) | ||
|
||
matching_detections = detections.PointSourceDetections.empty() | ||
matching_exposures = exposures.Exposures.empty() | ||
|
||
# Build a mapping of exposure_id to ephemeris ra and dec | ||
for exposure in self.all_observations.exposures: | ||
key = ephems_linkage.key( | ||
code=exposure.observatory_code[0].as_py(), | ||
mjd=exposure.midpoint().mjd()[0].as_py(), | ||
) | ||
ephem = ephems_linkage.select_left(key) | ||
assert len(ephem) == 1, "there should be exactly one ephemeris per exposure" | ||
|
||
ephem_ra = ephem.coordinates.lon[0].as_py() | ||
ephem_dec = ephem.coordinates.lat[0].as_py() | ||
|
||
exp_dets = self.all_observations.linkage.select_left(exposure.id[0]) | ||
|
||
nearby_dets = _within_radius(exp_dets, ephem_ra, ephem_dec, self.radius) | ||
if len(nearby_dets) > 0: | ||
matching_exposures = qv.concatenate([matching_exposures, exposure]) | ||
matching_detections = qv.concatenate([matching_detections, nearby_dets]) | ||
|
||
return Observations(matching_detections, matching_exposures) | ||
|
||
|
||
def _within_radius( | ||
detections: detections.PointSourceDetections, | ||
ra: float, | ||
dec: float, | ||
radius: float, | ||
) -> detections.PointSourceDetections: | ||
""" | ||
Return the detections within a given radius of a given ra and dec | ||
""" | ||
sdlon = np.sin(detections.ra.to_numpy() - ra) | ||
cdlon = np.cos(detections.ra.to_numpy() - ra) | ||
slat1 = np.sin(dec) | ||
slat2 = np.sin(detections.dec.to_numpy()) | ||
clat1 = np.cos(dec) | ||
clat2 = np.cos(detections.dec.to_numpy()) | ||
|
||
num1 = clat2 * sdlon | ||
num2 = clat1 * slat2 - slat1 * clat2 * cdlon | ||
denominator = slat1 * slat2 + clat1 * clat2 * cdlon | ||
|
||
distances = np.arctan2(np.hypot(num1, num2), denominator) | ||
|
||
mask = distances <= radius | ||
return detections.apply_mask(mask) | ||
|
||
|
||
class StaticObservationSource(ObservationSource): | ||
"""A StaticObservationSource is an ObservationSource that | ||
returns a fixed collection of observations for any test orbit. | ||
""" | ||
|
||
def __init__(self, observations: Observations): | ||
self.observations = observations | ||
|
||
def gather_observations(self, test_orbit: orbit.TestOrbit) -> Observations: | ||
return self.observations |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
import astropy.time | ||
import numpy as np | ||
import pyarrow as pa | ||
import pytest | ||
import quivr as qv | ||
from adam_core import coordinates, observers, propagator | ||
from adam_core.observations import detections, exposures | ||
|
||
from .. import observation_source, orbit | ||
|
||
|
||
@pytest.fixture | ||
def fixed_test_orbit(): | ||
# An orbit at 1AU going around at (about) 1 degree per day | ||
coords = coordinates.CartesianCoordinates.from_kwargs( | ||
x=[1], | ||
y=[0], | ||
z=[0], | ||
vx=[0], | ||
vy=[2*np.pi/365.25], | ||
vz=[0], | ||
time=coordinates.Times.from_astropy(astropy.time.Time("2020-01-01T00:00:00")), | ||
origin=coordinates.Origin.from_kwargs(code=["SUN"]), | ||
frame="ecliptic", | ||
) | ||
|
||
return orbit.TestOrbit( | ||
coordinates=coords, | ||
orbit_id="test_orbit", | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def fixed_observers(): | ||
times = astropy.time.Time( | ||
[ | ||
"2020-01-01T00:00:00", | ||
"2020-01-02T00:00:00", | ||
"2020-01-03T00:00:00", | ||
"2020-01-04T00:00:00", | ||
"2020-01-05T00:00:00", | ||
] | ||
) | ||
return observers.Observers.from_code("I11", times) | ||
|
||
|
||
@pytest.fixture | ||
def fixed_ephems(fixed_test_orbit, fixed_observers): | ||
prop = propagator.PYOORB() | ||
return prop.generate_ephemeris(fixed_test_orbit.orbit, fixed_observers).left_table | ||
|
||
|
||
@pytest.fixture | ||
def fixed_exposures(fixed_observers): | ||
return exposures.Exposures.from_kwargs( | ||
id=[str(i) for i in range(len(fixed_observers))], | ||
start_time=fixed_observers.coordinates.time, | ||
duration=[30 for i in range(len(fixed_observers))], | ||
filter=["i" for i in range(len(fixed_observers))], | ||
observatory_code=fixed_observers.code, | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def fixed_detections(fixed_ephems, fixed_exposures): | ||
# Return PointSourceDetections which form a 100 x 100 grid in | ||
# RA/Dec, evenly spanning 1 square degree, for each exposure | ||
detection_tables = [] | ||
for ephem, exposure in zip(fixed_ephems, fixed_exposures): | ||
ra_center = ephem.coordinates.lon[0].as_py() | ||
dec_center = ephem.coordinates.lat[0].as_py() | ||
|
||
ras = np.linspace(ra_center - 0.5, ra_center + 0.5, 100) | ||
decs = np.linspace(dec_center - 0.5, dec_center + 0.5, 100) | ||
|
||
ra_decs = np.meshgrid(ras, decs) | ||
|
||
N = len(ras) * len(decs) | ||
ids = [str(i) for i in range(N)] | ||
exposure_ids = pa.concat_arrays([exposure.id] * N) | ||
magnitudes = [20] * N | ||
|
||
detection_tables.append( | ||
detections.PointSourceDetections.from_kwargs( | ||
id=ids, | ||
exposure_id=exposure_ids, | ||
ra=ra_decs[0].flatten(), | ||
dec=ra_decs[1].flatten(), | ||
mag=magnitudes, | ||
) | ||
) | ||
return qv.concatenate(detection_tables) | ||
|
||
|
||
@pytest.fixture | ||
def fixed_observations(fixed_detections, fixed_exposures): | ||
return observation_source.Observations(fixed_detections, fixed_exposures) | ||
|
||
|
||
def test_observation_fixtures(fixed_test_orbit, fixed_observations): | ||
assert len(fixed_test_orbit.orbit) == 1 | ||
assert len(fixed_observations.exposures) == 5 | ||
assert len(fixed_observations.detections) == 100 * 100 * 5 | ||
|
||
|
||
def test_static_observation_source(fixed_test_orbit, fixed_observations): | ||
sos = observation_source.StaticObservationSource(observations=fixed_observations) | ||
have = sos.gather_observations(fixed_test_orbit) | ||
|
||
assert have == fixed_observations | ||
|
||
|
||
def test_fixed_radius_observation_source(fixed_test_orbit, fixed_observations): | ||
fos = observation_source.FixedRadiusObservationSource( | ||
radius=0.5, | ||
all_observations=fixed_observations, | ||
) | ||
have = fos.gather_observations(fixed_test_orbit) | ||
assert len(have.exposures) == 5 | ||
assert have.exposures == fixed_observations.exposures | ||
assert len(have.detections) < len(fixed_observations.detections) | ||
assert len(have.detections) > 0.75 * len(fixed_observations.detections) | ||
|