From: kelle Date: Sat, 7 Jan 2017 22:45:29 +0000 (-0500) Subject: Use parametrize instead of yield. X-Git-Tag: archive/raspbian/2.0.2-2+rpi1~1^2^2~2 X-Git-Url: https://dgit.raspbian.org/?a=commitdiff_plain;h=84029231f0024a745d69b10a9b2ec366d6cde2aa;p=python-astropy.git Use parametrize instead of yield. Pull request: https://github.com/astropy/astropy/pull/5678 Pull request: https://github.com/astropy/astropy/pull/5682 Gbp-Pq: Name Use-parametrize-instead-of-yield.patch --- diff --git a/astropy/io/votable/tests/vo_test.py b/astropy/io/votable/tests/vo_test.py index 587fc59..31acb9c 100644 --- a/astropy/io/votable/tests/vo_test.py +++ b/astropy/io/votable/tests/vo_test.py @@ -737,13 +737,10 @@ def table_from_scratch(): def test_open_files(): - def test_file(filename): - parse(filename, pedantic=False) - for filename in get_pkg_data_filenames('data', pattern='*.xml'): if filename.endswith('custom_datatype.xml'): continue - yield test_file, filename + parse(filename, pedantic=False) @raises(VOTableSpecError) diff --git a/astropy/units/tests/test_format.py b/astropy/units/tests/test_format.py index 9316a59..b5e9d12 100644 --- a/astropy/units/tests/test_format.py +++ b/astropy/units/tests/test_format.py @@ -13,7 +13,7 @@ from __future__ import (absolute_import, unicode_literals, division, from ...extern import six from numpy.testing.utils import assert_allclose -from ...tests.helper import raises, pytest, catch_warnings +from ...tests.helper import pytest, catch_warnings from ... import units as u from ...constants import si @@ -22,246 +22,205 @@ from .. import format as u_format from ..utils import is_effectively_unity -def test_unit_grammar(): - def _test_unit_grammar(s, unit): +@pytest.mark.parametrize('strings, unit', [ + (["m s", "m*s", "m.s"], u.m * u.s), + (["m/s", "m*s**-1", "m /s", "m / s", "m/ s"], u.m / u.s), + (["m**2", "m2", "m**(2)", "m**+2", "m+2", "m^(+2)"], u.m ** 2), + (["m**-3", "m-3", "m^(-3)", "/m3"], u.m ** -3), + (["m**(1.5)", "m(3/2)", "m**(3/2)", "m^(3/2)"], u.m ** 1.5), + (["2.54 cm"], u.Unit(u.cm * 2.54)), + (["10+8m"], u.Unit(u.m * 1e8)), + # This is the VOUnits documentation, but doesn't seem to follow the + # unity grammar (["3.45 10**(-4)Jy"], 3.45 * 1e-4 * u.Jy) + (["sqrt(m)"], u.m ** 0.5), + (["dB(mW)", "dB (mW)"], u.DecibelUnit(u.mW)), + (["mag"], u.mag), + (["mag(ct/s)"], u.MagUnit(u.ct / u.s)), + (["dex"], u.dex), + (["dex(cm s**-2)", "dex(cm/s2)"], u.DexUnit(u.cm / u.s**2))]) +def test_unit_grammar(strings, unit): + for s in strings: print(s) unit2 = u_format.Generic.parse(s) assert unit2 == unit - data = [ - (["m s", "m*s", "m.s"], u.m * u.s), - (["m/s", "m*s**-1", "m /s", "m / s", "m/ s"], u.m / u.s), - (["m**2", "m2", "m**(2)", "m**+2", "m+2", "m^(+2)"], u.m ** 2), - (["m**-3", "m-3", "m^(-3)", "/m3"], u.m ** -3), - (["m**(1.5)", "m(3/2)", "m**(3/2)", "m^(3/2)"], u.m ** 1.5), - (["2.54 cm"], u.Unit(u.cm * 2.54)), - (["10+8m"], u.Unit(u.m * 1e8)), - # This is the VOUnits documentation, but doesn't seem to follow the - # unity grammar (["3.45 10**(-4)Jy"], 3.45 * 1e-4 * u.Jy) - (["sqrt(m)"], u.m ** 0.5), - (["dB(mW)", "dB (mW)"], u.DecibelUnit(u.mW)), - (["mag"], u.mag), - (["mag(ct/s)"], u.MagUnit(u.ct / u.s)), - (["dex"], u.dex), - (["dex(cm s**-2)", "dex(cm/s2)"], u.DexUnit(u.cm / u.s**2)) - ] - - for strings, unit in data: - for s in strings: - yield _test_unit_grammar, s, unit - - -def test_unit_grammar_fail(): - @raises(ValueError) - def _test_unit_grammar_fail(s): - u_format.Generic.parse(s) - - data = ['sin( /pixel /s)', - 'mag(mag)', - 'dB(dB(mW))', - 'dex()'] - - for s in data: - yield _test_unit_grammar_fail, s - - -def test_cds_grammar(): - def _test_cds_grammar(s, unit): + +@pytest.mark.parametrize('string', ['sin( /pixel /s)', 'mag(mag)', + 'dB(dB(mW))', 'dex()']) +def test_unit_grammar_fail(string): + with pytest.raises(ValueError): + print(string) + u_format.Generic.parse(string) + +@pytest.mark.parametrize('strings, unit', [ + (["0.1nm"], u.AA), + (["mW/m2"], u.Unit(u.erg / u.cm ** 2 / u.s)), + (["mW/(m2)"], u.Unit(u.erg / u.cm ** 2 / u.s)), + (["km/s", "km.s-1"], u.km / u.s), + (["10pix/nm"], u.Unit(10 * u.pix / u.nm)), + (["1.5x10+11m"], u.Unit(1.5e11 * u.m)), + (["1.5×10+11m"], u.Unit(1.5e11 * u.m)), + (["m2"], u.m ** 2), + (["10+21m"], u.Unit(u.m * 1e21)), + (["2.54cm"], u.Unit(u.cm * 2.54)), + (["20%"], 0.20 * u.dimensionless_unscaled), + (["10+9"], 1.e9 * u.dimensionless_unscaled), + (["2x10-9"], 2.e-9 * u.dimensionless_unscaled), + (["---"], u.dimensionless_unscaled), + (["ma"], u.ma), + (["mAU"], u.mAU), + (["uarcmin"], u.uarcmin), + (["uarcsec"], u.uarcsec), + (["kbarn"], u.kbarn), + (["Gbit"], u.Gbit), + (["Gibit"], 2 ** 30 * u.bit), + (["kbyte"], u.kbyte), + (["mRy"], 0.001 * u.Ry), + (["mmag"], u.mmag), + (["Mpc"], u.Mpc), + (["Gyr"], u.Gyr), + (["°"], u.degree), + (["°/s"], u.degree / u.s), + (["Å"], u.AA), + (["Å/s"], u.AA / u.s), + (["\\h"], si.h)]) +def test_cds_grammar(strings, unit): + for s in strings: print(s) unit2 = u_format.CDS.parse(s) assert unit2 == unit - data = [ - (["0.1nm"], u.AA), - (["mW/m2"], u.Unit(u.erg / u.cm ** 2 / u.s)), - (["mW/(m2)"], u.Unit(u.erg / u.cm ** 2 / u.s)), - (["km/s", "km.s-1"], u.km / u.s), - (["10pix/nm"], u.Unit(10 * u.pix / u.nm)), - (["1.5x10+11m"], u.Unit(1.5e11 * u.m)), - (["1.5×10+11m"], u.Unit(1.5e11 * u.m)), - (["m2"], u.m ** 2), - (["10+21m"], u.Unit(u.m * 1e21)), - (["2.54cm"], u.Unit(u.cm * 2.54)), - (["20%"], 0.20 * u.dimensionless_unscaled), - (["10+9"], 1.e9 * u.dimensionless_unscaled), - (["2x10-9"], 2.e-9 * u.dimensionless_unscaled), - (["---"], u.dimensionless_unscaled), - (["ma"], u.ma), - (["mAU"], u.mAU), - (["uarcmin"], u.uarcmin), - (["uarcsec"], u.uarcsec), - (["kbarn"], u.kbarn), - (["Gbit"], u.Gbit), - (["Gibit"], 2 ** 30 * u.bit), - (["kbyte"], u.kbyte), - (["mRy"], 0.001 * u.Ry), - (["mmag"], u.mmag), - (["Mpc"], u.Mpc), - (["Gyr"], u.Gyr), - (["°"], u.degree), - (["°/s"], u.degree / u.s), - (["Å"], u.AA), - (["Å/s"], u.AA / u.s), - (["\\h"], si.h)] - - for strings, unit in data: - for s in strings: - yield _test_cds_grammar, s, unit - - -def test_cds_grammar_fail(): - @raises(ValueError) - def _test_cds_grammar_fail(s): - print(s) - u_format.CDS.parse(s) - - data = ['0.1 nm', - 'solMass(3/2)', - 'km / s', - 'km s-1', - 'pix0.1nm', - 'pix/(0.1nm)', - 'km*s', - 'km**2', - '5x8+3m', - '0.1---', - '---m', - 'm---', - 'mag(s-1)', - 'dB(mW)', - 'dex(cm s-2)'] - - for s in data: - yield _test_cds_grammar_fail, s - - -def test_ogip_grammar(): - def _test_ogip_grammar(s, unit): + +@pytest.mark.parametrize('string', [ + '0.1 nm', + 'solMass(3/2)', + 'km / s', + 'km s-1', + 'pix0.1nm', + 'pix/(0.1nm)', + 'km*s', + 'km**2', + '5x8+3m', + '0.1---', + '---m', + 'm---', + 'mag(s-1)', + 'dB(mW)', + 'dex(cm s-2)']) +def test_cds_grammar_fail(string): + with pytest.raises(ValueError): + print(string) + u_format.CDS.parse(string) + + +# These examples are taken from the EXAMPLES section of +# http://heasarc.gsfc.nasa.gov/docs/heasarc/ofwg/docs/general/ogip_93_001/ +@pytest.mark.parametrize('strings, unit', [ + (["count /s", "count/s", "count s**(-1)", "count / s", "count /s "], + u.count / u.s), + (["/pixel /s", "/(pixel * s)"], (u.pixel * u.s) ** -1), + (["count /m**2 /s /eV", "count m**(-2) * s**(-1) * eV**(-1)", + "count /(m**2 * s * eV)"], + u.count * u.m ** -2 * u.s ** -1 * u.eV ** -1), + (["erg /pixel /s /GHz", "erg /s /GHz /pixel", "erg /pixel /(s * GHz)"], + u.erg / (u.s * u.GHz * u.pixel)), + (["keV**2 /yr /angstrom", "10**(10) keV**2 /yr /m"], + # Though this is given as an example, it seems to violate the rules + # of not raising scales to powers, so I'm just excluding it + # "(10**2 MeV)**2 /yr /m" + u.keV**2 / (u.yr * u.angstrom)), + (["10**(46) erg /s", "10**46 erg /s", "10**(39) J /s", "10**(39) W", + "10**(15) YW", "YJ /fs"], + 10**46 * u.erg / u.s), + (["10**(-7) J /cm**2 /MeV", "10**(-9) J m**(-2) eV**(-1)", + "nJ m**(-2) eV**(-1)", "nJ /m**2 /eV"], + 10 ** -7 * u.J * u.cm ** -2 * u.MeV ** -1), + (["sqrt(erg /pixel /s /GHz)", "(erg /pixel /s /GHz)**(0.5)", + "(erg /pixel /s /GHz)**(1/2)", + "erg**(0.5) pixel**(-0.5) s**(-0.5) GHz**(-0.5)"], + (u.erg * u.pixel ** -1 * u.s ** -1 * u.GHz ** -1) ** 0.5), + (["(count /s) (/pixel /s)", "(count /s) * (/pixel /s)", + "count /pixel /s**2"], + (u.count / u.s) * (1.0 / (u.pixel * u.s)))]) +def test_ogip_grammar(strings, unit): + for s in strings: print(s) unit2 = u_format.OGIP.parse(s) assert unit2 == unit - # These examples are taken from the EXAMPLES section of - # http://heasarc.gsfc.nasa.gov/docs/heasarc/ofwg/docs/general/ogip_93_001/ - data = [ - (["count /s", "count/s", "count s**(-1)", "count / s", "count /s "], - u.count / u.s), - (["/pixel /s", "/(pixel * s)"], (u.pixel * u.s) ** -1), - (["count /m**2 /s /eV", "count m**(-2) * s**(-1) * eV**(-1)", - "count /(m**2 * s * eV)"], - u.count * u.m ** -2 * u.s ** -1 * u.eV ** -1), - (["erg /pixel /s /GHz", "erg /s /GHz /pixel", "erg /pixel /(s * GHz)"], - u.erg / (u.s * u.GHz * u.pixel)), - (["keV**2 /yr /angstrom", "10**(10) keV**2 /yr /m", - # Though this is given as an example, it seems to violate the rules - # of not raising scales to powers, so I'm just excluding it - # "(10**2 MeV)**2 /yr /m" - ], - u.keV**2 / (u.yr * u.angstrom)), - (["10**(46) erg /s", "10**46 erg /s", "10**(39) J /s", "10**(39) W", - "10**(15) YW", "YJ /fs"], - 10**46 * u.erg / u.s), - (["10**(-7) J /cm**2 /MeV", "10**(-9) J m**(-2) eV**(-1)", - "nJ m**(-2) eV**(-1)", "nJ /m**2 /eV"], - 10 ** -7 * u.J * u.cm ** -2 * u.MeV ** -1), - (["sqrt(erg /pixel /s /GHz)", "(erg /pixel /s /GHz)**(0.5)", - "(erg /pixel /s /GHz)**(1/2)", - "erg**(0.5) pixel**(-0.5) s**(-0.5) GHz**(-0.5)"], - (u.erg * u.pixel ** -1 * u.s ** -1 * u.GHz ** -1) ** 0.5), - (["(count /s) (/pixel /s)", "(count /s) * (/pixel /s)", - "count /pixel /s**2"], - (u.count / u.s) * (1.0 / (u.pixel * u.s)))] - - for strings, unit in data: - for s in strings: - yield _test_ogip_grammar, s, unit - - -def test_ogip_grammar_fail(): - @raises(ValueError) - def _test_ogip_grammar_fail(s): - u_format.OGIP.parse(s) - - data = ['log(photon /m**2 /s /Hz)', - 'sin( /pixel /s)', - 'log(photon /cm**2 /s /Hz) /(sin( /pixel /s))', - 'log(photon /cm**2 /s /Hz) (sin( /pixel /s))**(-1)', - 'dB(mW)', 'dex(cm/s**2)'] - - for s in data: - yield _test_ogip_grammar_fail, s - - -def test_roundtrip(): - def _test_roundtrip(unit): - a = core.Unit(unit.to_string('generic'), format='generic') - b = core.Unit(unit.decompose().to_string('generic'), format='generic') - assert_allclose(a.decompose().scale, unit.decompose().scale, rtol=1e-2) - assert_allclose(b.decompose().scale, unit.decompose().scale, rtol=1e-2) - - for key, val in u.__dict__.items(): - if isinstance(val, core.UnitBase) and not isinstance(val, core.PrefixUnit): - yield _test_roundtrip, val - - -def test_roundtrip_vo_unit(): - def _test_roundtrip_vo_unit(unit, skip_decompose): - a = core.Unit(unit.to_string('vounit'), format='vounit') - assert_allclose(a.decompose().scale, unit.decompose().scale, rtol=1e-2) - if skip_decompose: - return - u = unit.decompose().to_string('vounit') - assert ' ' not in u - b = core.Unit(u, format='vounit') - assert_allclose(b.decompose().scale, unit.decompose().scale, rtol=1e-2) - - x = u_format.VOUnit - for key, val in x._units.items(): - if isinstance(val, core.UnitBase) and not isinstance(val, core.PrefixUnit): - yield _test_roundtrip_vo_unit, val, val in (u.mag, u.dB) - - -def test_roundtrip_fits(): - def _test_roundtrip_fits(unit): - s = unit.to_string('fits') - a = core.Unit(s, format='fits') - assert_allclose(a.decompose().scale, unit.decompose().scale, rtol=1e-2) - - for key, val in u_format.Fits._units.items(): - if isinstance(val, core.UnitBase) and not isinstance(val, core.PrefixUnit): - yield _test_roundtrip_fits, val - -def test_roundtrip_cds(): - def _test_roundtrip_cds(unit): - a = core.Unit(unit.to_string('cds'), format='cds') - assert_allclose(a.decompose().scale, unit.decompose().scale, rtol=1e-2) - try: - b = core.Unit(unit.decompose().to_string('cds'), format='cds') - except ValueError: # skip mag: decomposes into dex, unknown to OGIP - return +@pytest.mark.parametrize('string', [ + 'log(photon /m**2 /s /Hz)', + 'sin( /pixel /s)', + 'log(photon /cm**2 /s /Hz) /(sin( /pixel /s))', + 'log(photon /cm**2 /s /Hz) (sin( /pixel /s))**(-1)', + 'dB(mW)', 'dex(cm/s**2)']) +def test_ogip_grammar_fail(string): + with pytest.raises(ValueError): + print(string) + u_format.OGIP.parse(string) + + +@pytest.mark.parametrize('unit', [val for key, val in u.__dict__.items() + if (isinstance(val, core.UnitBase) and + not isinstance(val, core.PrefixUnit))]) +def test_roundtrip(unit): + a = core.Unit(unit.to_string('generic'), format='generic') + b = core.Unit(unit.decompose().to_string('generic'), format='generic') + assert_allclose(a.decompose().scale, unit.decompose().scale, rtol=1e-2) + assert_allclose(b.decompose().scale, unit.decompose().scale, rtol=1e-2) + + +@pytest.mark.parametrize('unit', [ + val for key, val in u_format.VOUnit._units.items() + if (isinstance(val, core.UnitBase) and + not isinstance(val, core.PrefixUnit))]) +def test_roundtrip_vo_unit(unit): + a = core.Unit(unit.to_string('vounit'), format='vounit') + assert_allclose(a.decompose().scale, unit.decompose().scale, rtol=1e-2) + if unit not in (u.mag, u.dB): + ud = unit.decompose().to_string('vounit') + assert ' ' not in ud + b = core.Unit(ud, format='vounit') assert_allclose(b.decompose().scale, unit.decompose().scale, rtol=1e-2) - x = u_format.CDS - for key, val in x._units.items(): - if isinstance(val, core.UnitBase) and not isinstance(val, core.PrefixUnit): - yield _test_roundtrip_cds, val +@pytest.mark.parametrize('unit', [ + val for key, val in u_format.Fits._units.items() + if (isinstance(val, core.UnitBase) and + not isinstance(val, core.PrefixUnit))]) +def test_roundtrip_fits(unit): + s = unit.to_string('fits') + a = core.Unit(s, format='fits') + assert_allclose(a.decompose().scale, unit.decompose().scale, rtol=1e-2) -def test_roundtrip_ogip(): - def _test_roundtrip_ogip(unit): - a = core.Unit(unit.to_string('ogip'), format='ogip') - assert_allclose(a.decompose().scale, unit.decompose().scale, rtol=1e-2) - try: - b = core.Unit(unit.decompose().to_string('ogip'), format='ogip') - except ValueError: # skip mag: decomposes into dex, unknown to OGIP - return - assert_allclose(b.decompose().scale, unit.decompose().scale, rtol=1e-2) - x = u_format.OGIP - for key, val in x._units.items(): - if isinstance(val, core.UnitBase) and not isinstance(val, core.PrefixUnit): - yield _test_roundtrip_ogip, val +@pytest.mark.parametrize('unit', [ + val for key, val in u_format.CDS._units.items() + if (isinstance(val, core.UnitBase) and + not isinstance(val, core.PrefixUnit))]) +def test_roundtrip_cds(unit): + a = core.Unit(unit.to_string('cds'), format='cds') + assert_allclose(a.decompose().scale, unit.decompose().scale, rtol=1e-2) + try: + b = core.Unit(unit.decompose().to_string('cds'), format='cds') + except ValueError: # skip mag: decomposes into dex, unknown to OGIP + return + assert_allclose(b.decompose().scale, unit.decompose().scale, rtol=1e-2) + + +@pytest.mark.parametrize('unit', [ + val for key, val in u_format.OGIP._units.items() + if (isinstance(val, core.UnitBase) and + not isinstance(val, core.PrefixUnit))]) +def test_roundtrip_ogip(unit): + a = core.Unit(unit.to_string('ogip'), format='ogip') + assert_allclose(a.decompose().scale, unit.decompose().scale, rtol=1e-2) + try: + b = core.Unit(unit.decompose().to_string('ogip'), format='ogip') + except ValueError: # skip mag: decomposes into dex, unknown to OGIP + return + assert_allclose(b.decompose().scale, unit.decompose().scale, rtol=1e-2) def test_fits_units_available(): @@ -307,22 +266,16 @@ def test_latex_inline_scale(): assert fluxunit.to_string('latex_inline') == latex_inline -def test_format_styles(): +@pytest.mark.parametrize('format_spec, string', [ + ('generic','erg / (cm2 s)'), + ('s', 'erg / (cm2 s)'), + ('console', ' erg \n ------\n s cm^2'), + ('latex', '$\\mathrm{\\frac{erg}{s\\,cm^{2}}}$'), + ('latex_inline', '$\\mathrm{erg\\,s^{-1}\\,cm^{-2}}$'), + ('>20s',' erg / (cm2 s)')]) +def test_format_styles(format_spec, string): fluxunit = u.erg / (u.cm ** 2 * u.s) - def _test_format_styles(format_spec, s): - assert format(fluxunit, format_spec) == s - - format_s_pairs = [ - ('generic','erg / (cm2 s)'), - ('s', 'erg / (cm2 s)'), - ('console', ' erg \n ------\n s cm^2'), - ('latex', '$\\mathrm{\\frac{erg}{s\\,cm^{2}}}$'), - ('latex_inline', '$\\mathrm{erg\\,s^{-1}\\,cm^{-2}}$'), - ('>20s',' erg / (cm2 s)'), - ] - - for format_, s in format_s_pairs: - yield _test_format_styles, format_, s + assert format(fluxunit, format_spec) == string def test_flatten_to_known(): @@ -332,10 +285,9 @@ def test_flatten_to_known(): assert myunit2.to_string('fits') == 'bit3 erg Hz-1' -@raises(ValueError) def test_flatten_impossible(): myunit = u.def_unit("FOOBAR_Two") - with u.add_enabled_units(myunit): + with u.add_enabled_units(myunit), pytest.raises(ValueError): myunit.to_string('fits') @@ -434,34 +386,20 @@ def test_deprecated_did_you_mean_units(): assert '0.1nm' in six.text_type(w[0].message) -def test_fits_function(): +@pytest.mark.parametrize('string', ['mag(ct/s)', 'dB(mW)', 'dex(cm s**-2)']) +def test_fits_function(string): # Function units cannot be written, so ensure they're not parsed either. - @raises(ValueError) - def _test_fits_grammar_fail(s): - print(s) - u_format.Fits().parse(s) - - data = ['mag(ct/s)', - 'dB(mW)', - 'dex(cm s**-2)'] - - for s in data: - yield _test_fits_grammar_fail, s + with pytest.raises(ValueError): + print(string) + u_format.Fits().parse(string) -def test_vounit_function(): +@pytest.mark.parametrize('string', ['mag(ct/s)', 'dB(mW)', 'dex(cm s**-2)']) +def test_vounit_function(string): # Function units cannot be written, so ensure they're not parsed either. - @raises(ValueError) - def _test_vounit_grammar_fail(s): - print(s) - u_format.VOUnit().parse(s) - - data = ['mag(ct/s)', - 'dB(mW)', - 'dex(cm s**-2)'] - - for s in data: - yield _test_vounit_grammar_fail, s + with pytest.raises(ValueError): + print(string) + u_format.VOUnit().parse(string) def test_vounit_binary_prefix(): diff --git a/astropy/wcs/tests/test_profiling.py b/astropy/wcs/tests/test_profiling.py index a17a076..57d9a93 100644 --- a/astropy/wcs/tests/test_profiling.py +++ b/astropy/wcs/tests/test_profiling.py @@ -7,84 +7,61 @@ import os import numpy as np +from ...tests.helper import pytest from ...utils.data import get_pkg_data_filenames, get_pkg_data_contents from ...utils.misc import NumpyRNGContext from ... import wcs +#hdr_map_file_list = list(get_pkg_data_filenames("maps", pattern="*.hdr")) -def test_maps(): - def test_map(filename): - header = get_pkg_data_contents(os.path.join("maps", filename)) - wcsobj = wcs.WCS(header) - - with NumpyRNGContext(123456789): - x = np.random.rand(2 ** 12, wcsobj.wcs.naxis) - world = wcsobj.wcs_pix2world(x, 1) - pix = wcsobj.wcs_world2pix(x, 1) - - hdr_file_list = list(get_pkg_data_filenames("maps", pattern="*.hdr")) - - # actually perform a test for each one - for filename in hdr_file_list: - - # use the base name of the file, because everything we yield - # will show up in the test name in the pandokia report - filename = os.path.basename(filename) - - # yield a function name and parameters to make a generated test - yield test_map, filename +# use the base name of the file, because everything we yield +# will show up in the test name in the pandokia report +hdr_map_file_list = [os.path.basename(fname) for fname in get_pkg_data_filenames("maps", pattern="*.hdr")] - # AFTER we tested with every file that we found, check to see that we - # actually have the list we expect. If N=0, we will not have performed - # any tests at all. If N < n_data_files, we are missing some files, - # so we will have skipped some tests. Without this check, both cases - # happen silently! +# Checking the number of files before reading them in. +# OLD COMMENTS: +# AFTER we tested with every file that we found, check to see that we +# actually have the list we expect. If N=0, we will not have performed +# any tests at all. If N < n_data_files, we are missing some files, +# so we will have skipped some tests. Without this check, both cases +# happen silently! - # how many do we expect to see? - n_data_files = 28 - - if len(hdr_file_list) != n_data_files: - assert False, ( - "test_maps has wrong number data files: found {}, expected " - " {}".format(len(hdr_file_list), n_data_files)) - # b.t.w. If this assert happens, py.test reports one more test - # than it would have otherwise. +def test_read_map_files(): + # how many map files we expect to see + n_map_files = 28 + assert len(hdr_map_file_list) == n_map_files, ( + "test_read_map_files has wrong number data files: found {}, expected " + " {}".format(len(hdr_map_file_list), n_map_files)) -def test_spectra(): - def test_spectrum(filename): - header = get_pkg_data_contents(os.path.join("spectra", filename)) +@pytest.mark.parametrize("filename", hdr_map_file_list) +def test_map(filename): + header = get_pkg_data_contents(os.path.join("maps", filename)) wcsobj = wcs.WCS(header) + with NumpyRNGContext(123456789): - x = np.random.rand(2 ** 16, wcsobj.wcs.naxis) + x = np.random.rand(2 ** 12, wcsobj.wcs.naxis) world = wcsobj.wcs_pix2world(x, 1) pix = wcsobj.wcs_world2pix(x, 1) - hdr_file_list = list(get_pkg_data_filenames("spectra", pattern="*.hdr")) +hdr_spec_file_list = [os.path.basename(fname) for fname in get_pkg_data_filenames("spectra", pattern="*.hdr")] - # actually perform a test for each one - for filename in hdr_file_list: +def test_read_spec_files(): + # how many spec files expected + n_spec_files = 6 - # use the base name of the file, because everything we yield - # will show up in the test name in the pandokia report - filename = os.path.basename(filename) - - # yield a function name and parameters to make a generated test - yield test_spectrum, filename - - # AFTER we tested with every file that we found, check to see that we - # actually have the list we expect. If N=0, we will not have performed - # any tests at all. If N < n_data_files, we are missing some files, - # so we will have skipped some tests. Without this check, both cases - # happen silently! - - # how many do we expect to see? - n_data_files = 6 - - if len(hdr_file_list) != n_data_files: - assert False, ( + assert len(hdr_spec_file_list) == n_spec_files, ( "test_spectra has wrong number data files: found {}, expected " - " {}".format(len(hdr_file_list), n_data_files)) + " {}".format(len(hdr_spec_file_list), n_spec_files)) # b.t.w. If this assert happens, py.test reports one more test # than it would have otherwise. + +@pytest.mark.parametrize("filename", hdr_spec_file_list) +def test_spectrum(filename): + header = get_pkg_data_contents(os.path.join("spectra", filename)) + wcsobj = wcs.WCS(header) + with NumpyRNGContext(123456789): + x = np.random.rand(2 ** 16, wcsobj.wcs.naxis) + world = wcsobj.wcs_pix2world(x, 1) + pix = wcsobj.wcs_world2pix(x, 1) diff --git a/astropy/wcs/tests/test_wcs.py b/astropy/wcs/tests/test_wcs.py index cf9db9a..12e77bc 100644 --- a/astropy/wcs/tests/test_wcs.py +++ b/astropy/wcs/tests/test_wcs.py @@ -26,99 +26,67 @@ from ...io import fits from ...extern.six.moves import range -# test_maps() is a generator -def test_maps(): +class TestMaps(object): + def setup(self): + # get the list of the hdr files that we want to test + self._file_list = list(get_pkg_data_filenames("maps", pattern="*.hdr")) - # test_map() is the function that is called to perform the generated test - def test_map(filename): - - # the test parameter is the base name of the file to use; find - # the file in the installed wcs test directory - header = get_pkg_data_contents( - os.path.join("maps", filename), encoding='binary') - wcsobj = wcs.WCS(header) - - world = wcsobj.wcs_pix2world([[97, 97]], 1) - - assert_array_almost_equal(world, [[285.0, -66.25]], decimal=1) - - pix = wcsobj.wcs_world2pix([[285.0, -66.25]], 1) - - assert_array_almost_equal(pix, [[97, 97]], decimal=0) - - # get the list of the hdr files that we want to test - hdr_file_list = list(get_pkg_data_filenames("maps", pattern="*.hdr")) - - # actually perform a test for each one - for filename in hdr_file_list: - - # use the base name of the file, because everything we yield - # will show up in the test name in the pandokia report - filename = os.path.basename(filename) - - # yield a function name and parameters to make a generated test - yield test_map, filename - - # AFTER we tested with every file that we found, check to see that we - # actually have the list we expect. If N=0, we will not have performed - # any tests at all. If N < n_data_files, we are missing some files, - # so we will have skipped some tests. Without this check, both cases - # happen silently! - - # how many do we expect to see? - n_data_files = 28 - - if len(hdr_file_list) != n_data_files: - assert False, ( - "test_maps has wrong number data files: found {}, expected " - " {}".format(len(hdr_file_list), n_data_files)) - # b.t.w. If this assert happens, py.test reports one more test - # than it would have otherwise. + def test_consistency(self): + # Check to see that we actually have the list we expect, so that we + # do not get in a situation where the list is empty or incomplete and + # the tests still seem to pass correctly. + # how many do we expect to see? + n_data_files = 28 -# test_spectra() is a generator -def test_spectra(): - - # test_spectrum() is the function that is called to perform the - # generated test - def test_spectrum(filename): - - # the test parameter is the base name of the file to use; find - # the file in the installed wcs test directory - header = get_pkg_data_contents( - os.path.join("spectra", filename), encoding='binary') - - all_wcs = wcs.find_all_wcs(header) - assert len(all_wcs) == 9 - - # get the list of the hdr files that we want to test - hdr_file_list = list(get_pkg_data_filenames("spectra", pattern="*.hdr")) - - # actually perform a test for each one - for filename in hdr_file_list: - - # use the base name of the file, because everything we yield - # will show up in the test name in the pandokia report - filename = os.path.basename(filename) - - # yield a function name and parameters to make a generated test - yield test_spectrum, filename - - # AFTER we tested with every file that we found, check to see that we - # actually have the list we expect. If N=0, we will not have performed - # any tests at all. If N < n_data_files, we are missing some files, - # so we will have skipped some tests. Without this check, both cases - # happen silently! - - # how many do we expect to see? - n_data_files = 6 - - if len(hdr_file_list) != n_data_files: - assert False, ( + assert len(self._file_list) == n_data_files, ( + "test_spectra has wrong number data files: found {}, expected " + " {}".format(len(self._file_list), n_data_files)) + + def test_maps(self): + for filename in self._file_list: + # use the base name of the file, so we get more useful messages + # for failing tests. + filename = os.path.basename(filename) + # Now find the associated file in the installed wcs test directory. + header = get_pkg_data_contents( + os.path.join("maps", filename), encoding='binary') + # finally run the test. + wcsobj = wcs.WCS(header) + world = wcsobj.wcs_pix2world([[97, 97]], 1) + assert_array_almost_equal(world, [[285.0, -66.25]], decimal=1) + pix = wcsobj.wcs_world2pix([[285.0, -66.25]], 1) + assert_array_almost_equal(pix, [[97, 97]], decimal=0) + + +class TestSpectra(object): + def setup(self): + self._file_list = list(get_pkg_data_filenames("spectra", + pattern="*.hdr")) + + def test_consistency(self): + # Check to see that we actually have the list we expect, so that we + # do not get in a situation where the list is empty or incomplete and + # the tests still seem to pass correctly. + + # how many do we expect to see? + n_data_files = 6 + + assert len(self._file_list) == n_data_files, ( "test_spectra has wrong number data files: found {}, expected " - " {}".format(len(hdr_file_list), n_data_files)) - # b.t.w. If this assert happens, py.test reports one more test - # than it would have otherwise. + " {}".format(len(self._file_list), n_data_files)) + + def test_spectra(self): + for filename in self._file_list: + # use the base name of the file, so we get more useful messages + # for failing tests. + filename = os.path.basename(filename) + # Now find the associated file in the installed wcs test directory. + header = get_pkg_data_contents( + os.path.join("spectra", filename), encoding='binary') + # finally run the test. + all_wcs = wcs.find_all_wcs(header) + assert len(all_wcs) == 9 def test_fixes():