Use parametrize instead of yield.
authorkelle <kellecruz@gmail.com>
Sat, 7 Jan 2017 22:45:29 +0000 (17:45 -0500)
committerOle Streicher <olebole@debian.org>
Wed, 25 Jan 2017 15:17:26 +0000 (15:17 +0000)
Pull request: https://github.com/astropy/astropy/pull/5678
Pull request: https://github.com/astropy/astropy/pull/5682

Gbp-Pq: Name Use-parametrize-instead-of-yield.patch

astropy/io/votable/tests/vo_test.py
astropy/units/tests/test_format.py
astropy/wcs/tests/test_profiling.py
astropy/wcs/tests/test_wcs.py

index 587fc591661745fc2448bb6856a5552b3b3689f0..31acb9c9ec23d3f40fd78eb62213136f386d9624 100644 (file)
@@ -737,13 +737,10 @@ def table_from_scratch():
 
 
 def test_open_files():
-    def test_file(filename):
-        parse(filename, pedantic=False)
-
     for filename in get_pkg_data_filenames('data', pattern='*.xml'):
         if filename.endswith('custom_datatype.xml'):
             continue
-        yield test_file, filename
+        parse(filename, pedantic=False)
 
 
 @raises(VOTableSpecError)
index 9316a5969a3386eec6e41dc72bb926ae06a4875d..b5e9d12e5d8e08d71583f1446f7066a689cf5499 100644 (file)
@@ -13,7 +13,7 @@ from __future__ import (absolute_import, unicode_literals, division,
 from ...extern import six
 
 from numpy.testing.utils import assert_allclose
-from ...tests.helper import raises, pytest, catch_warnings
+from ...tests.helper import pytest, catch_warnings
 
 from ... import units as u
 from ...constants import si
@@ -22,246 +22,205 @@ from .. import format as u_format
 from ..utils import is_effectively_unity
 
 
-def test_unit_grammar():
-    def _test_unit_grammar(s, unit):
+@pytest.mark.parametrize('strings, unit', [
+    (["m s", "m*s", "m.s"], u.m * u.s),
+    (["m/s", "m*s**-1", "m /s", "m / s", "m/ s"], u.m / u.s),
+    (["m**2", "m2", "m**(2)", "m**+2", "m+2", "m^(+2)"], u.m ** 2),
+    (["m**-3", "m-3", "m^(-3)", "/m3"], u.m ** -3),
+    (["m**(1.5)", "m(3/2)", "m**(3/2)", "m^(3/2)"], u.m ** 1.5),
+    (["2.54 cm"], u.Unit(u.cm * 2.54)),
+    (["10+8m"], u.Unit(u.m * 1e8)),
+    # This is the VOUnits documentation, but doesn't seem to follow the
+    # unity grammar (["3.45 10**(-4)Jy"], 3.45 * 1e-4 * u.Jy)
+    (["sqrt(m)"], u.m ** 0.5),
+    (["dB(mW)", "dB (mW)"], u.DecibelUnit(u.mW)),
+    (["mag"], u.mag),
+    (["mag(ct/s)"], u.MagUnit(u.ct / u.s)),
+    (["dex"], u.dex),
+    (["dex(cm s**-2)", "dex(cm/s2)"], u.DexUnit(u.cm / u.s**2))])
+def test_unit_grammar(strings, unit):
+    for s in strings:
         print(s)
         unit2 = u_format.Generic.parse(s)
         assert unit2 == unit
 
-    data = [
-        (["m s", "m*s", "m.s"], u.m * u.s),
-        (["m/s", "m*s**-1", "m /s", "m / s", "m/ s"], u.m / u.s),
-        (["m**2", "m2", "m**(2)", "m**+2", "m+2", "m^(+2)"], u.m ** 2),
-        (["m**-3", "m-3", "m^(-3)", "/m3"], u.m ** -3),
-        (["m**(1.5)", "m(3/2)", "m**(3/2)", "m^(3/2)"], u.m ** 1.5),
-        (["2.54 cm"], u.Unit(u.cm * 2.54)),
-        (["10+8m"], u.Unit(u.m * 1e8)),
-        # This is the VOUnits documentation, but doesn't seem to follow the
-        # unity grammar (["3.45 10**(-4)Jy"], 3.45 * 1e-4 * u.Jy)
-        (["sqrt(m)"], u.m ** 0.5),
-        (["dB(mW)", "dB (mW)"], u.DecibelUnit(u.mW)),
-        (["mag"], u.mag),
-        (["mag(ct/s)"], u.MagUnit(u.ct / u.s)),
-        (["dex"], u.dex),
-        (["dex(cm s**-2)", "dex(cm/s2)"], u.DexUnit(u.cm / u.s**2))
-    ]
-
-    for strings, unit in data:
-        for s in strings:
-            yield _test_unit_grammar, s, unit
-
-
-def test_unit_grammar_fail():
-    @raises(ValueError)
-    def _test_unit_grammar_fail(s):
-        u_format.Generic.parse(s)
-
-    data = ['sin( /pixel /s)',
-            'mag(mag)',
-            'dB(dB(mW))',
-            'dex()']
-
-    for s in data:
-        yield _test_unit_grammar_fail, s
-
-
-def test_cds_grammar():
-    def _test_cds_grammar(s, unit):
+
+@pytest.mark.parametrize('string', ['sin( /pixel /s)', 'mag(mag)',
+                                    'dB(dB(mW))', 'dex()'])
+def test_unit_grammar_fail(string):
+    with pytest.raises(ValueError):
+        print(string)
+        u_format.Generic.parse(string)
+
+@pytest.mark.parametrize('strings, unit', [
+    (["0.1nm"], u.AA),
+    (["mW/m2"], u.Unit(u.erg / u.cm ** 2 / u.s)),
+    (["mW/(m2)"], u.Unit(u.erg / u.cm ** 2 / u.s)),
+    (["km/s", "km.s-1"], u.km / u.s),
+    (["10pix/nm"], u.Unit(10 * u.pix / u.nm)),
+    (["1.5x10+11m"], u.Unit(1.5e11 * u.m)),
+    (["1.5×10+11m"], u.Unit(1.5e11 * u.m)),
+    (["m2"], u.m ** 2),
+    (["10+21m"], u.Unit(u.m * 1e21)),
+    (["2.54cm"], u.Unit(u.cm * 2.54)),
+    (["20%"], 0.20 * u.dimensionless_unscaled),
+    (["10+9"], 1.e9 * u.dimensionless_unscaled),
+    (["2x10-9"], 2.e-9 * u.dimensionless_unscaled),
+    (["---"], u.dimensionless_unscaled),
+    (["ma"], u.ma),
+    (["mAU"], u.mAU),
+    (["uarcmin"], u.uarcmin),
+    (["uarcsec"], u.uarcsec),
+    (["kbarn"], u.kbarn),
+    (["Gbit"], u.Gbit),
+    (["Gibit"], 2 ** 30 * u.bit),
+    (["kbyte"], u.kbyte),
+    (["mRy"], 0.001 * u.Ry),
+    (["mmag"], u.mmag),
+    (["Mpc"], u.Mpc),
+    (["Gyr"], u.Gyr),
+    (["°"], u.degree),
+    (["°/s"], u.degree / u.s),
+    (["Å"], u.AA),
+    (["Å/s"], u.AA / u.s),
+    (["\\h"], si.h)])
+def test_cds_grammar(strings, unit):
+    for s in strings:
         print(s)
         unit2 = u_format.CDS.parse(s)
         assert unit2 == unit
 
-    data = [
-        (["0.1nm"], u.AA),
-        (["mW/m2"], u.Unit(u.erg / u.cm ** 2 / u.s)),
-        (["mW/(m2)"], u.Unit(u.erg / u.cm ** 2 / u.s)),
-        (["km/s", "km.s-1"], u.km / u.s),
-        (["10pix/nm"], u.Unit(10 * u.pix / u.nm)),
-        (["1.5x10+11m"], u.Unit(1.5e11 * u.m)),
-        (["1.5×10+11m"], u.Unit(1.5e11 * u.m)),
-        (["m2"], u.m ** 2),
-        (["10+21m"], u.Unit(u.m * 1e21)),
-        (["2.54cm"], u.Unit(u.cm * 2.54)),
-        (["20%"], 0.20 * u.dimensionless_unscaled),
-        (["10+9"], 1.e9 * u.dimensionless_unscaled),
-        (["2x10-9"], 2.e-9 * u.dimensionless_unscaled),
-        (["---"], u.dimensionless_unscaled),
-        (["ma"], u.ma),
-        (["mAU"], u.mAU),
-        (["uarcmin"], u.uarcmin),
-        (["uarcsec"], u.uarcsec),
-        (["kbarn"], u.kbarn),
-        (["Gbit"], u.Gbit),
-        (["Gibit"], 2 ** 30 * u.bit),
-        (["kbyte"], u.kbyte),
-        (["mRy"], 0.001 * u.Ry),
-        (["mmag"], u.mmag),
-        (["Mpc"], u.Mpc),
-        (["Gyr"], u.Gyr),
-        (["°"], u.degree),
-        (["°/s"], u.degree / u.s),
-        (["Å"], u.AA),
-        (["Å/s"], u.AA / u.s),
-        (["\\h"], si.h)]
-
-    for strings, unit in data:
-        for s in strings:
-            yield _test_cds_grammar, s, unit
-
-
-def test_cds_grammar_fail():
-    @raises(ValueError)
-    def _test_cds_grammar_fail(s):
-        print(s)
-        u_format.CDS.parse(s)
-
-    data = ['0.1 nm',
-            'solMass(3/2)',
-            'km / s',
-            'km s-1',
-            'pix0.1nm',
-            'pix/(0.1nm)',
-            'km*s',
-            'km**2',
-            '5x8+3m',
-            '0.1---',
-            '---m',
-            'm---',
-            'mag(s-1)',
-            'dB(mW)',
-            'dex(cm s-2)']
-
-    for s in data:
-        yield _test_cds_grammar_fail, s
-
-
-def test_ogip_grammar():
-    def _test_ogip_grammar(s, unit):
+
+@pytest.mark.parametrize('string', [
+    '0.1 nm',
+    'solMass(3/2)',
+    'km / s',
+    'km s-1',
+    'pix0.1nm',
+    'pix/(0.1nm)',
+    'km*s',
+    'km**2',
+    '5x8+3m',
+    '0.1---',
+    '---m',
+    'm---',
+    'mag(s-1)',
+    'dB(mW)',
+    'dex(cm s-2)'])
+def test_cds_grammar_fail(string):
+    with pytest.raises(ValueError):
+        print(string)
+        u_format.CDS.parse(string)
+
+
+# These examples are taken from the EXAMPLES section of
+# http://heasarc.gsfc.nasa.gov/docs/heasarc/ofwg/docs/general/ogip_93_001/
+@pytest.mark.parametrize('strings, unit', [
+    (["count /s", "count/s", "count s**(-1)", "count / s", "count /s "],
+     u.count / u.s),
+    (["/pixel /s", "/(pixel * s)"], (u.pixel * u.s) ** -1),
+    (["count /m**2 /s /eV", "count m**(-2) * s**(-1) * eV**(-1)",
+      "count /(m**2 * s * eV)"],
+     u.count * u.m ** -2 * u.s ** -1 * u.eV ** -1),
+    (["erg /pixel /s /GHz", "erg /s /GHz /pixel", "erg /pixel /(s * GHz)"],
+     u.erg / (u.s * u.GHz * u.pixel)),
+    (["keV**2 /yr /angstrom", "10**(10) keV**2 /yr /m"],
+      # Though this is given as an example, it seems to violate the rules
+      # of not raising scales to powers, so I'm just excluding it
+      # "(10**2 MeV)**2 /yr /m"
+     u.keV**2 / (u.yr * u.angstrom)),
+    (["10**(46) erg /s", "10**46 erg /s", "10**(39) J /s", "10**(39) W",
+      "10**(15) YW", "YJ /fs"],
+     10**46 * u.erg / u.s),
+    (["10**(-7) J /cm**2 /MeV", "10**(-9) J m**(-2) eV**(-1)",
+      "nJ m**(-2) eV**(-1)", "nJ /m**2 /eV"],
+     10 ** -7 * u.J * u.cm ** -2 * u.MeV ** -1),
+    (["sqrt(erg /pixel /s /GHz)", "(erg /pixel /s /GHz)**(0.5)",
+      "(erg /pixel /s /GHz)**(1/2)",
+      "erg**(0.5) pixel**(-0.5) s**(-0.5) GHz**(-0.5)"],
+     (u.erg * u.pixel ** -1 * u.s ** -1 * u.GHz ** -1) ** 0.5),
+    (["(count /s) (/pixel /s)", "(count /s) * (/pixel /s)",
+      "count /pixel /s**2"],
+     (u.count / u.s) * (1.0 / (u.pixel * u.s)))])
+def test_ogip_grammar(strings, unit):
+    for s in strings:
         print(s)
         unit2 = u_format.OGIP.parse(s)
         assert unit2 == unit
 
-    # These examples are taken from the EXAMPLES section of
-    # http://heasarc.gsfc.nasa.gov/docs/heasarc/ofwg/docs/general/ogip_93_001/
-    data = [
-        (["count /s", "count/s", "count s**(-1)", "count / s", "count /s "],
-         u.count / u.s),
-        (["/pixel /s", "/(pixel * s)"], (u.pixel * u.s) ** -1),
-        (["count /m**2 /s /eV", "count m**(-2) * s**(-1) * eV**(-1)",
-          "count /(m**2 * s * eV)"],
-         u.count * u.m ** -2 * u.s ** -1 * u.eV ** -1),
-        (["erg /pixel /s /GHz", "erg /s /GHz /pixel", "erg /pixel /(s * GHz)"],
-         u.erg / (u.s * u.GHz * u.pixel)),
-        (["keV**2 /yr /angstrom", "10**(10) keV**2 /yr /m",
-          # Though this is given as an example, it seems to violate the rules
-          # of not raising scales to powers, so I'm just excluding it
-          # "(10**2 MeV)**2 /yr /m"
-         ],
-         u.keV**2 / (u.yr * u.angstrom)),
-        (["10**(46) erg /s", "10**46 erg /s", "10**(39) J /s", "10**(39) W",
-          "10**(15) YW", "YJ /fs"],
-         10**46 * u.erg / u.s),
-        (["10**(-7) J /cm**2 /MeV", "10**(-9) J m**(-2) eV**(-1)",
-          "nJ m**(-2) eV**(-1)", "nJ /m**2 /eV"],
-         10 ** -7 * u.J * u.cm ** -2 * u.MeV ** -1),
-        (["sqrt(erg /pixel /s /GHz)", "(erg /pixel /s /GHz)**(0.5)",
-          "(erg /pixel /s /GHz)**(1/2)",
-          "erg**(0.5) pixel**(-0.5) s**(-0.5) GHz**(-0.5)"],
-         (u.erg * u.pixel ** -1 * u.s ** -1 * u.GHz ** -1) ** 0.5),
-        (["(count /s) (/pixel /s)", "(count /s) * (/pixel /s)",
-          "count /pixel /s**2"],
-         (u.count / u.s) * (1.0 / (u.pixel * u.s)))]
-
-    for strings, unit in data:
-        for s in strings:
-            yield _test_ogip_grammar, s, unit
-
-
-def test_ogip_grammar_fail():
-    @raises(ValueError)
-    def _test_ogip_grammar_fail(s):
-        u_format.OGIP.parse(s)
-
-    data = ['log(photon /m**2 /s /Hz)',
-            'sin( /pixel /s)',
-            'log(photon /cm**2 /s /Hz) /(sin( /pixel /s))',
-            'log(photon /cm**2 /s /Hz) (sin( /pixel /s))**(-1)',
-            'dB(mW)', 'dex(cm/s**2)']
-
-    for s in data:
-        yield _test_ogip_grammar_fail, s
-
-
-def test_roundtrip():
-    def _test_roundtrip(unit):
-        a = core.Unit(unit.to_string('generic'), format='generic')
-        b = core.Unit(unit.decompose().to_string('generic'), format='generic')
-        assert_allclose(a.decompose().scale, unit.decompose().scale, rtol=1e-2)
-        assert_allclose(b.decompose().scale, unit.decompose().scale, rtol=1e-2)
-
-    for key, val in u.__dict__.items():
-        if isinstance(val, core.UnitBase) and not isinstance(val, core.PrefixUnit):
-            yield _test_roundtrip, val
-
-
-def test_roundtrip_vo_unit():
-    def _test_roundtrip_vo_unit(unit, skip_decompose):
-        a = core.Unit(unit.to_string('vounit'), format='vounit')
-        assert_allclose(a.decompose().scale, unit.decompose().scale, rtol=1e-2)
-        if skip_decompose:
-            return
-        u = unit.decompose().to_string('vounit')
-        assert '  ' not in u
-        b = core.Unit(u, format='vounit')
-        assert_allclose(b.decompose().scale, unit.decompose().scale, rtol=1e-2)
-
-    x = u_format.VOUnit
-    for key, val in x._units.items():
-        if isinstance(val, core.UnitBase) and not isinstance(val, core.PrefixUnit):
-            yield _test_roundtrip_vo_unit, val, val in (u.mag, u.dB)
-
-
-def test_roundtrip_fits():
-    def _test_roundtrip_fits(unit):
-        s = unit.to_string('fits')
-        a = core.Unit(s, format='fits')
-        assert_allclose(a.decompose().scale, unit.decompose().scale, rtol=1e-2)
-
-    for key, val in u_format.Fits._units.items():
-        if isinstance(val, core.UnitBase) and not isinstance(val, core.PrefixUnit):
-            yield _test_roundtrip_fits, val
-
 
-def test_roundtrip_cds():
-    def _test_roundtrip_cds(unit):
-        a = core.Unit(unit.to_string('cds'), format='cds')
-        assert_allclose(a.decompose().scale, unit.decompose().scale, rtol=1e-2)
-        try:
-            b = core.Unit(unit.decompose().to_string('cds'), format='cds')
-        except ValueError:  # skip mag: decomposes into dex, unknown to OGIP
-            return
+@pytest.mark.parametrize('string', [
+    'log(photon /m**2 /s /Hz)',
+    'sin( /pixel /s)',
+    'log(photon /cm**2 /s /Hz) /(sin( /pixel /s))',
+    'log(photon /cm**2 /s /Hz) (sin( /pixel /s))**(-1)',
+    'dB(mW)', 'dex(cm/s**2)'])
+def test_ogip_grammar_fail(string):
+    with pytest.raises(ValueError):
+        print(string)
+        u_format.OGIP.parse(string)
+
+
+@pytest.mark.parametrize('unit', [val for key, val in u.__dict__.items()
+                                  if (isinstance(val, core.UnitBase) and
+                                      not isinstance(val, core.PrefixUnit))])
+def test_roundtrip(unit):
+    a = core.Unit(unit.to_string('generic'), format='generic')
+    b = core.Unit(unit.decompose().to_string('generic'), format='generic')
+    assert_allclose(a.decompose().scale, unit.decompose().scale, rtol=1e-2)
+    assert_allclose(b.decompose().scale, unit.decompose().scale, rtol=1e-2)
+
+
+@pytest.mark.parametrize('unit', [
+    val for key, val in u_format.VOUnit._units.items()
+    if (isinstance(val, core.UnitBase) and
+        not isinstance(val, core.PrefixUnit))])
+def test_roundtrip_vo_unit(unit):
+    a = core.Unit(unit.to_string('vounit'), format='vounit')
+    assert_allclose(a.decompose().scale, unit.decompose().scale, rtol=1e-2)
+    if unit not in (u.mag, u.dB):
+        ud = unit.decompose().to_string('vounit')
+        assert '  ' not in ud
+        b = core.Unit(ud, format='vounit')
         assert_allclose(b.decompose().scale, unit.decompose().scale, rtol=1e-2)
 
-    x = u_format.CDS
-    for key, val in x._units.items():
-        if isinstance(val, core.UnitBase) and not isinstance(val, core.PrefixUnit):
-            yield _test_roundtrip_cds, val
 
+@pytest.mark.parametrize('unit', [
+    val for key, val in u_format.Fits._units.items()
+    if (isinstance(val, core.UnitBase) and
+        not isinstance(val, core.PrefixUnit))])
+def test_roundtrip_fits(unit):
+    s = unit.to_string('fits')
+    a = core.Unit(s, format='fits')
+    assert_allclose(a.decompose().scale, unit.decompose().scale, rtol=1e-2)
 
-def test_roundtrip_ogip():
-    def _test_roundtrip_ogip(unit):
-        a = core.Unit(unit.to_string('ogip'), format='ogip')
-        assert_allclose(a.decompose().scale, unit.decompose().scale, rtol=1e-2)
-        try:
-            b = core.Unit(unit.decompose().to_string('ogip'), format='ogip')
-        except ValueError:  # skip mag: decomposes into dex, unknown to OGIP
-            return
-        assert_allclose(b.decompose().scale, unit.decompose().scale, rtol=1e-2)
 
-    x = u_format.OGIP
-    for key, val in x._units.items():
-        if isinstance(val, core.UnitBase) and not isinstance(val, core.PrefixUnit):
-            yield _test_roundtrip_ogip, val
+@pytest.mark.parametrize('unit', [
+    val for key, val in u_format.CDS._units.items()
+    if (isinstance(val, core.UnitBase) and
+        not isinstance(val, core.PrefixUnit))])
+def test_roundtrip_cds(unit):
+    a = core.Unit(unit.to_string('cds'), format='cds')
+    assert_allclose(a.decompose().scale, unit.decompose().scale, rtol=1e-2)
+    try:
+        b = core.Unit(unit.decompose().to_string('cds'), format='cds')
+    except ValueError:  # skip mag: decomposes into dex, unknown to OGIP
+        return
+    assert_allclose(b.decompose().scale, unit.decompose().scale, rtol=1e-2)
+
+
+@pytest.mark.parametrize('unit', [
+    val for key, val in u_format.OGIP._units.items()
+    if (isinstance(val, core.UnitBase) and
+        not isinstance(val, core.PrefixUnit))])
+def test_roundtrip_ogip(unit):
+    a = core.Unit(unit.to_string('ogip'), format='ogip')
+    assert_allclose(a.decompose().scale, unit.decompose().scale, rtol=1e-2)
+    try:
+        b = core.Unit(unit.decompose().to_string('ogip'), format='ogip')
+    except ValueError:  # skip mag: decomposes into dex, unknown to OGIP
+        return
+    assert_allclose(b.decompose().scale, unit.decompose().scale, rtol=1e-2)
 
 
 def test_fits_units_available():
@@ -307,22 +266,16 @@ def test_latex_inline_scale():
     assert fluxunit.to_string('latex_inline') == latex_inline
 
 
-def test_format_styles():
+@pytest.mark.parametrize('format_spec, string', [
+    ('generic','erg / (cm2 s)'),
+    ('s', 'erg / (cm2 s)'),
+    ('console', '  erg  \n ------\n s cm^2'),
+    ('latex', '$\\mathrm{\\frac{erg}{s\\,cm^{2}}}$'),
+    ('latex_inline', '$\\mathrm{erg\\,s^{-1}\\,cm^{-2}}$'),
+    ('>20s','       erg / (cm2 s)')])
+def test_format_styles(format_spec, string):
     fluxunit = u.erg / (u.cm ** 2 * u.s)
-    def _test_format_styles(format_spec, s):
-        assert format(fluxunit, format_spec) == s
-
-    format_s_pairs = [
-        ('generic','erg / (cm2 s)'),
-        ('s', 'erg / (cm2 s)'),
-        ('console', '  erg  \n ------\n s cm^2'),
-        ('latex', '$\\mathrm{\\frac{erg}{s\\,cm^{2}}}$'),
-        ('latex_inline', '$\\mathrm{erg\\,s^{-1}\\,cm^{-2}}$'),
-        ('>20s','       erg / (cm2 s)'),
-    ]
-
-    for format_, s in format_s_pairs:
-        yield _test_format_styles, format_, s
+    assert format(fluxunit, format_spec) == string
 
 
 def test_flatten_to_known():
@@ -332,10 +285,9 @@ def test_flatten_to_known():
     assert myunit2.to_string('fits') == 'bit3 erg Hz-1'
 
 
-@raises(ValueError)
 def test_flatten_impossible():
     myunit = u.def_unit("FOOBAR_Two")
-    with u.add_enabled_units(myunit):
+    with u.add_enabled_units(myunit), pytest.raises(ValueError):
         myunit.to_string('fits')
 
 
@@ -434,34 +386,20 @@ def test_deprecated_did_you_mean_units():
     assert '0.1nm' in six.text_type(w[0].message)
 
 
-def test_fits_function():
+@pytest.mark.parametrize('string', ['mag(ct/s)', 'dB(mW)', 'dex(cm s**-2)'])
+def test_fits_function(string):
     # Function units cannot be written, so ensure they're not parsed either.
-    @raises(ValueError)
-    def _test_fits_grammar_fail(s):
-        print(s)
-        u_format.Fits().parse(s)
-
-    data = ['mag(ct/s)',
-            'dB(mW)',
-            'dex(cm s**-2)']
-
-    for s in data:
-        yield _test_fits_grammar_fail, s
+    with pytest.raises(ValueError):
+        print(string)
+        u_format.Fits().parse(string)
 
 
-def test_vounit_function():
+@pytest.mark.parametrize('string', ['mag(ct/s)', 'dB(mW)', 'dex(cm s**-2)'])
+def test_vounit_function(string):
     # Function units cannot be written, so ensure they're not parsed either.
-    @raises(ValueError)
-    def _test_vounit_grammar_fail(s):
-        print(s)
-        u_format.VOUnit().parse(s)
-
-    data = ['mag(ct/s)',
-            'dB(mW)',
-            'dex(cm s**-2)']
-
-    for s in data:
-        yield _test_vounit_grammar_fail, s
+    with pytest.raises(ValueError):
+        print(string)
+        u_format.VOUnit().parse(string)
 
 
 def test_vounit_binary_prefix():
index a17a076cac9df62d5d7656c41282cab26ef8dbf5..57d9a932db6ff9bf6d69d2ab28812cb4612af5b9 100644 (file)
@@ -7,84 +7,61 @@ import os
 
 import numpy as np
 
+from ...tests.helper import pytest
 from ...utils.data import get_pkg_data_filenames, get_pkg_data_contents
 from ...utils.misc import NumpyRNGContext
 
 from ... import wcs
 
+#hdr_map_file_list = list(get_pkg_data_filenames("maps", pattern="*.hdr"))
 
-def test_maps():
-    def test_map(filename):
-        header = get_pkg_data_contents(os.path.join("maps", filename))
-        wcsobj = wcs.WCS(header)
-
-        with NumpyRNGContext(123456789):
-            x = np.random.rand(2 ** 12, wcsobj.wcs.naxis)
-            world = wcsobj.wcs_pix2world(x, 1)
-            pix = wcsobj.wcs_world2pix(x, 1)
-
-    hdr_file_list = list(get_pkg_data_filenames("maps", pattern="*.hdr"))
-
-    # actually perform a test for each one
-    for filename in hdr_file_list:
-
-        # use the base name of the file, because everything we yield
-        # will show up in the test name in the pandokia report
-        filename = os.path.basename(filename)
-
-        # yield a function name and parameters to make a generated test
-        yield test_map, filename
+# use the base name of the file, because everything we yield
+# will show up in the test name in the pandokia report
+hdr_map_file_list = [os.path.basename(fname) for fname in get_pkg_data_filenames("maps", pattern="*.hdr")]
 
-    # AFTER we tested with every file that we found, check to see that we
-    # actually have the list we expect.  If N=0, we will not have performed
-    # any tests at all.  If N < n_data_files, we are missing some files,
-    # so we will have skipped some tests.  Without this check, both cases
-    # happen silently!
+# Checking the number of files before reading them in.
+# OLD COMMENTS:
+# AFTER we tested with every file that we found, check to see that we
+# actually have the list we expect.  If N=0, we will not have performed
+# any tests at all.  If N < n_data_files, we are missing some files,
+# so we will have skipped some tests.  Without this check, both cases
+# happen silently!
 
-    # how many do we expect to see?
-    n_data_files = 28
-
-    if len(hdr_file_list) != n_data_files:
-        assert False, (
-            "test_maps has wrong number data files: found {}, expected "
-            " {}".format(len(hdr_file_list), n_data_files))
-        # b.t.w.  If this assert happens, py.test reports one more test
-        # than it would have otherwise.
+def test_read_map_files():
+    # how many map files we expect to see
+    n_map_files = 28
 
+    assert len(hdr_map_file_list) == n_map_files, (
+           "test_read_map_files has wrong number data files: found {}, expected "
+           " {}".format(len(hdr_map_file_list), n_map_files))
 
-def test_spectra():
-    def test_spectrum(filename):
-        header = get_pkg_data_contents(os.path.join("spectra", filename))
+@pytest.mark.parametrize("filename", hdr_map_file_list)
+def test_map(filename):
+        header = get_pkg_data_contents(os.path.join("maps", filename))
         wcsobj = wcs.WCS(header)
+
         with NumpyRNGContext(123456789):
-            x = np.random.rand(2 ** 16, wcsobj.wcs.naxis)
+            x = np.random.rand(2 ** 12, wcsobj.wcs.naxis)
             world = wcsobj.wcs_pix2world(x, 1)
             pix = wcsobj.wcs_world2pix(x, 1)
 
-    hdr_file_list = list(get_pkg_data_filenames("spectra", pattern="*.hdr"))
+hdr_spec_file_list = [os.path.basename(fname) for fname in get_pkg_data_filenames("spectra", pattern="*.hdr")]
 
-    # actually perform a test for each one
-    for filename in hdr_file_list:
+def test_read_spec_files():
+    # how many spec files expected
+    n_spec_files = 6
 
-        # use the base name of the file, because everything we yield
-        # will show up in the test name in the pandokia report
-        filename = os.path.basename(filename)
-
-        # yield a function name and parameters to make a generated test
-        yield test_spectrum, filename
-
-    # AFTER we tested with every file that we found, check to see that we
-    # actually have the list we expect.  If N=0, we will not have performed
-    # any tests at all.  If N < n_data_files, we are missing some files,
-    # so we will have skipped some tests.  Without this check, both cases
-    # happen silently!
-
-    # how many do we expect to see?
-    n_data_files = 6
-
-    if len(hdr_file_list) != n_data_files:
-        assert False, (
+    assert len(hdr_spec_file_list) == n_spec_files, (
             "test_spectra has wrong number data files: found {}, expected "
-            " {}".format(len(hdr_file_list), n_data_files))
+            " {}".format(len(hdr_spec_file_list), n_spec_files))
         # b.t.w.  If this assert happens, py.test reports one more test
         # than it would have otherwise.
+
+@pytest.mark.parametrize("filename", hdr_spec_file_list)
+def test_spectrum(filename):
+    header = get_pkg_data_contents(os.path.join("spectra", filename))
+    wcsobj = wcs.WCS(header)
+    with NumpyRNGContext(123456789):
+        x = np.random.rand(2 ** 16, wcsobj.wcs.naxis)
+        world = wcsobj.wcs_pix2world(x, 1)
+        pix = wcsobj.wcs_world2pix(x, 1)
index cf9db9ae7242fbaba808ede2a2122a72029d34f6..12e77bc9e3d102d784d69bb3e53641a373d45e7b 100644 (file)
@@ -26,99 +26,67 @@ from ...io import fits
 from ...extern.six.moves import range
 
 
-# test_maps() is a generator
-def test_maps():
+class TestMaps(object):
+    def setup(self):
+        # get the list of the hdr files that we want to test
+        self._file_list = list(get_pkg_data_filenames("maps", pattern="*.hdr"))
 
-    # test_map() is the function that is called to perform the generated test
-    def test_map(filename):
-
-        # the test parameter is the base name of the file to use; find
-        # the file in the installed wcs test directory
-        header = get_pkg_data_contents(
-            os.path.join("maps", filename), encoding='binary')
-        wcsobj = wcs.WCS(header)
-
-        world = wcsobj.wcs_pix2world([[97, 97]], 1)
-
-        assert_array_almost_equal(world, [[285.0, -66.25]], decimal=1)
-
-        pix = wcsobj.wcs_world2pix([[285.0, -66.25]], 1)
-
-        assert_array_almost_equal(pix, [[97, 97]], decimal=0)
-
-    # get the list of the hdr files that we want to test
-    hdr_file_list = list(get_pkg_data_filenames("maps", pattern="*.hdr"))
-
-    # actually perform a test for each one
-    for filename in hdr_file_list:
-
-        # use the base name of the file, because everything we yield
-        # will show up in the test name in the pandokia report
-        filename = os.path.basename(filename)
-
-        # yield a function name and parameters to make a generated test
-        yield test_map, filename
-
-    # AFTER we tested with every file that we found, check to see that we
-    # actually have the list we expect.  If N=0, we will not have performed
-    # any tests at all.  If N < n_data_files, we are missing some files,
-    # so we will have skipped some tests.  Without this check, both cases
-    # happen silently!
-
-    # how many do we expect to see?
-    n_data_files = 28
-
-    if len(hdr_file_list) != n_data_files:
-        assert False, (
-            "test_maps has wrong number data files: found {}, expected "
-            " {}".format(len(hdr_file_list), n_data_files))
-        # b.t.w.  If this assert happens, py.test reports one more test
-        # than it would have otherwise.
+    def test_consistency(self):
+        # Check to see that we actually have the list we expect, so that we
+        # do not get in a situation where the list is empty or incomplete and
+        # the tests still seem to pass correctly.
 
+        # how many do we expect to see?
+        n_data_files = 28
 
-# test_spectra() is a generator
-def test_spectra():
-
-    # test_spectrum() is the function that is called to perform the
-    # generated test
-    def test_spectrum(filename):
-
-        # the test parameter is the base name of the file to use; find
-        # the file in the installed wcs test directory
-        header = get_pkg_data_contents(
-            os.path.join("spectra", filename), encoding='binary')
-
-        all_wcs = wcs.find_all_wcs(header)
-        assert len(all_wcs) == 9
-
-    # get the list of the hdr files that we want to test
-    hdr_file_list = list(get_pkg_data_filenames("spectra", pattern="*.hdr"))
-
-    # actually perform a test for each one
-    for filename in hdr_file_list:
-
-        # use the base name of the file, because everything we yield
-        # will show up in the test name in the pandokia report
-        filename = os.path.basename(filename)
-
-        # yield a function name and parameters to make a generated test
-        yield test_spectrum, filename
-
-    # AFTER we tested with every file that we found, check to see that we
-    # actually have the list we expect.  If N=0, we will not have performed
-    # any tests at all.  If N < n_data_files, we are missing some files,
-    # so we will have skipped some tests.  Without this check, both cases
-    # happen silently!
-
-    # how many do we expect to see?
-    n_data_files = 6
-
-    if len(hdr_file_list) != n_data_files:
-        assert False, (
+        assert len(self._file_list) == n_data_files, (
+            "test_spectra has wrong number data files: found {}, expected "
+            " {}".format(len(self._file_list), n_data_files))
+
+    def test_maps(self):
+        for filename in self._file_list:
+            # use the base name of the file, so we get more useful messages
+            # for failing tests.
+            filename = os.path.basename(filename)
+            # Now find the associated file in the installed wcs test directory.
+            header = get_pkg_data_contents(
+                os.path.join("maps", filename), encoding='binary')
+            # finally run the test.
+            wcsobj = wcs.WCS(header)
+            world = wcsobj.wcs_pix2world([[97, 97]], 1)
+            assert_array_almost_equal(world, [[285.0, -66.25]], decimal=1)
+            pix = wcsobj.wcs_world2pix([[285.0, -66.25]], 1)
+            assert_array_almost_equal(pix, [[97, 97]], decimal=0)
+
+
+class TestSpectra(object):
+    def setup(self):
+        self._file_list = list(get_pkg_data_filenames("spectra",
+                                                      pattern="*.hdr"))
+
+    def test_consistency(self):
+        # Check to see that we actually have the list we expect, so that we
+        # do not get in a situation where the list is empty or incomplete and
+        # the tests still seem to pass correctly.
+
+        # how many do we expect to see?
+        n_data_files = 6
+
+        assert len(self._file_list) == n_data_files, (
             "test_spectra has wrong number data files: found {}, expected "
-            " {}".format(len(hdr_file_list), n_data_files))
-        # b.t.w.  If this assert happens, py.test reports one more test
-        # than it would have otherwise.
+            " {}".format(len(self._file_list), n_data_files))
+
+    def test_spectra(self):
+        for filename in self._file_list:
+            # use the base name of the file, so we get more useful messages
+            # for failing tests.
+            filename = os.path.basename(filename)
+            # Now find the associated file in the installed wcs test directory.
+            header = get_pkg_data_contents(
+                os.path.join("spectra", filename), encoding='binary')
+            # finally run the test.
+            all_wcs = wcs.find_all_wcs(header)
+            assert len(all_wcs) == 9
 
 
 def test_fixes():