[PATCH 1/2] BUG: interpolate/RGI: upcast float32 to float64
authorEvgeni Burovski <evgeny.burovskiy@gmail.com>
Thu, 5 Jan 2023 09:29:15 +0000 (12:29 +0300)
committerDrew Parsons <dparsons@debian.org>
Sat, 18 Feb 2023 15:12:32 +0000 (15:12 +0000)
Gbp-Pq: Name fix_cast_PR17726.patch

scipy/interpolate/_rgi.py
scipy/interpolate/tests/test_rgi.py

index ec9f55f76aad233fde0ef79f5340763ee7fa6fab..204bc3eaf406cbce86fb8035727ff11f2d8ca395 100644 (file)
@@ -332,6 +332,7 @@ class RegularGridInterpolator:
             indices, norm_distances = self._find_indices(xi.T)
             if (ndim == 2 and hasattr(self.values, 'dtype') and
                     self.values.ndim == 2 and self.values.flags.writeable and
+                    self.values.dtype in (np.float64, np.complex128) and
                     self.values.dtype.byteorder == '='):
                 # until cython supports const fused types, the fast path
                 # cannot support non-writeable values
index cb2c54cf3701f41e989104c48a98ba19974efd97..5ff52547ba38e45048f8f009fd02a853ccda73b5 100644 (file)
@@ -588,6 +588,35 @@ class TestRegularGridInterpolator:
         v2 = np.expand_dims(vs, axis=0)
         assert_allclose(v, v2, atol=1e-14, err_msg=method)
 
+    @pytest.mark.parametrize(
+        "dtype",
+        [np.float32, np.float64, np.complex64, np.complex128]
+    )
+    @pytest.mark.parametrize("xi_dtype", [np.float32, np.float64])
+    def test_float32_values(self, dtype, xi_dtype):
+        # regression test for gh-17718: values.dtype=float32 fails
+        def f(x, y):
+            return 2 * x**3 + 3 * y**2
+
+        x = np.linspace(1, 4, 11)
+        y = np.linspace(4, 7, 22)
+
+        xg, yg = np.meshgrid(x, y, indexing='ij', sparse=True)
+        data = f(xg, yg)
+
+        data = data.astype(dtype)
+
+        interp = RegularGridInterpolator((x, y), data)
+
+        pts = np.array([[2.1, 6.2],
+                        [3.3, 5.2]], dtype=xi_dtype)
+
+        # the values here are just what the call returns; the test checks that
+        # that the call succeeds at all, instead of failing with cython not
+        # having a float32 kernel
+        assert_allclose(interp(pts), [134.10469388, 153.40069388], atol=1e-7)
+
+
 class MyValue:
     """
     Minimal indexable object