From b99f9fd3adfc613c4205993a73c1299567c9bcde Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 5 Jan 2023 12:29:15 +0300 Subject: [PATCH] [PATCH 1/2] BUG: interpolate/RGI: upcast float32 to float64 Gbp-Pq: Name fix_cast_PR17726.patch --- scipy/interpolate/_rgi.py | 1 + scipy/interpolate/tests/test_rgi.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/scipy/interpolate/_rgi.py b/scipy/interpolate/_rgi.py index ec9f55f7..204bc3ea 100644 --- a/scipy/interpolate/_rgi.py +++ b/scipy/interpolate/_rgi.py @@ -332,6 +332,7 @@ class RegularGridInterpolator: indices, norm_distances = self._find_indices(xi.T) if (ndim == 2 and hasattr(self.values, 'dtype') and self.values.ndim == 2 and self.values.flags.writeable and + self.values.dtype in (np.float64, np.complex128) and self.values.dtype.byteorder == '='): # until cython supports const fused types, the fast path # cannot support non-writeable values diff --git a/scipy/interpolate/tests/test_rgi.py b/scipy/interpolate/tests/test_rgi.py index cb2c54cf..5ff52547 100644 --- a/scipy/interpolate/tests/test_rgi.py +++ b/scipy/interpolate/tests/test_rgi.py @@ -588,6 +588,35 @@ class TestRegularGridInterpolator: v2 = np.expand_dims(vs, axis=0) assert_allclose(v, v2, atol=1e-14, err_msg=method) + @pytest.mark.parametrize( + "dtype", + [np.float32, np.float64, np.complex64, np.complex128] + ) + @pytest.mark.parametrize("xi_dtype", [np.float32, np.float64]) + def test_float32_values(self, dtype, xi_dtype): + # regression test for gh-17718: values.dtype=float32 fails + def f(x, y): + return 2 * x**3 + 3 * y**2 + + x = np.linspace(1, 4, 11) + y = np.linspace(4, 7, 22) + + xg, yg = np.meshgrid(x, y, indexing='ij', sparse=True) + data = f(xg, yg) + + data = data.astype(dtype) + + interp = RegularGridInterpolator((x, y), data) + + pts = np.array([[2.1, 6.2], + [3.3, 5.2]], dtype=xi_dtype) + + # the values here are just what the call returns; the test checks that + # that the call succeeds at all, instead of failing with cython not + # having a float32 kernel + assert_allclose(interp(pts), [134.10469388, 153.40069388], atol=1e-7) + + class MyValue: """ Minimal indexable object -- 2.30.2