Coverage for jetgp/full_gddegp/gddegp_utils.py: 65%
433 statements
« prev ^ index » next coverage.py v7.10.7, created at 2026-04-03 15:09 -0500
« prev ^ index » next coverage.py v7.10.7, created at 2026-04-03 15:09 -0500
1import numpy as np
2import pyoti.core as coti
3from line_profiler import profile
4import numba
7# =============================================================================
8# Numba-accelerated helper functions for efficient matrix slicing
9# =============================================================================
@numba.jit(nopython=True, cache=True)
def extract_rows(content_full, row_indices, n_cols):
    """
    Gather selected rows of a matrix into a freshly allocated array.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols)
        Matrix to read from.
    row_indices : ndarray of int64
        Rows of ``content_full`` to copy, in output order.
    n_cols : int
        Number of leading columns copied from every selected row.

    Returns
    -------
    result : ndarray of shape (len(row_indices), n_cols)
        ``result[i, j] == content_full[row_indices[i], j]``.
    """
    n_selected = len(row_indices)
    out = np.empty((n_selected, n_cols))
    for dst_row in range(n_selected):
        src_row = row_indices[dst_row]
        for col in range(n_cols):
            out[dst_row, col] = content_full[src_row, col]
    return out
@numba.jit(nopython=True, cache=True)
def extract_cols(content_full, col_indices, n_rows):
    """
    Gather selected columns of a matrix into a freshly allocated array.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows, n_cols_full)
        Matrix to read from.
    col_indices : ndarray of int64
        Columns of ``content_full`` to copy, in output order.
    n_rows : int
        Number of leading rows copied from every selected column.

    Returns
    -------
    result : ndarray of shape (n_rows, len(col_indices))
        ``result[i, j] == content_full[i, col_indices[j]]``.
    """
    n_selected = len(col_indices)
    out = np.empty((n_rows, n_selected))
    for row in range(n_rows):
        for dst_col in range(n_selected):
            src_col = col_indices[dst_col]
            out[row, dst_col] = content_full[row, src_col]
    return out
@numba.jit(nopython=True, cache=True)
def extract_submatrix(content_full, row_indices, col_indices):
    """
    Gather the block at the cross product of the given rows and columns.

    Loop-based stand-in for ``content_full[np.ix_(row_indices, col_indices)]``.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols_full)
        Matrix to read from.
    row_indices : ndarray of int64
        Rows to select.
    col_indices : ndarray of int64
        Columns to select.

    Returns
    -------
    result : ndarray of shape (len(row_indices), len(col_indices))
        ``result[i, j] == content_full[row_indices[i], col_indices[j]]``.
    """
    n_sel_rows = len(row_indices)
    n_sel_cols = len(col_indices)
    out = np.empty((n_sel_rows, n_sel_cols))
    for dst_row in range(n_sel_rows):
        src_row = row_indices[dst_row]
        for dst_col in range(n_sel_cols):
            src_col = col_indices[dst_col]
            out[dst_row, dst_col] = content_full[src_row, src_col]
    return out
@numba.jit(nopython=True, cache=True)
def extract_submatrix_transposed(content_full, row_indices, col_indices):
    """
    Gather a row/column-selected block and return it transposed.

    Loop-based stand-in for
    ``content_full[np.ix_(row_indices, col_indices)].T``.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols_full)
        Matrix to read from.
    row_indices : ndarray of int64
        Rows to select.
    col_indices : ndarray of int64
        Columns to select.

    Returns
    -------
    result : ndarray of shape (len(col_indices), len(row_indices))
        ``result[j, i] == content_full[row_indices[i], col_indices[j]]``.
    """
    n_sel_rows = len(row_indices)
    n_sel_cols = len(col_indices)
    out = np.empty((n_sel_cols, n_sel_rows))
    for src_i in range(n_sel_rows):
        src_row = row_indices[src_i]
        for src_j in range(n_sel_cols):
            src_col = col_indices[src_j]
            # Swapped destination indices perform the transpose.
            out[src_j, src_i] = content_full[src_row, src_col]
    return out
@numba.jit(nopython=True, cache=True)
def extract_rows_transposed(content_full, row_indices, n_cols):
    """
    Gather selected rows and return them transposed.

    Loop-based stand-in for ``content_full[row_indices, :].T``.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols)
        Matrix to read from.
    row_indices : ndarray of int64
        Rows to select.
    n_cols : int
        Number of leading columns copied from every selected row.

    Returns
    -------
    result : ndarray of shape (n_cols, len(row_indices))
        ``result[j, i] == content_full[row_indices[i], j]``.
    """
    n_selected = len(row_indices)
    out = np.empty((n_cols, n_selected))
    for src_i in range(n_selected):
        src_row = row_indices[src_i]
        for col in range(n_cols):
            out[col, src_i] = content_full[src_row, col]
    return out
@numba.jit(nopython=True, cache=True)
def extract_cols_transposed(content_full, col_indices, n_rows):
    """
    Gather selected columns and return them transposed.

    Loop-based stand-in for ``content_full[:, col_indices].T``.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows, n_cols_full)
        Matrix to read from.
    col_indices : ndarray of int64
        Columns to select.
    n_rows : int
        Number of leading rows copied from every selected column.

    Returns
    -------
    result : ndarray of shape (len(col_indices), n_rows)
        ``result[j, i] == content_full[i, col_indices[j]]``.
    """
    n_selected = len(col_indices)
    out = np.empty((n_selected, n_rows))
    for row in range(n_rows):
        for src_j in range(n_selected):
            src_col = col_indices[src_j]
            out[src_j, row] = content_full[row, src_col]
    return out
@numba.jit(nopython=True, cache=True, parallel=False)
def extract_and_assign(content_full, row_indices, col_indices, K,
                       row_start, col_start):
    """
    Copy a row/column-selected block of ``content_full`` into ``K`` in place.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols_full)
        Matrix to read from.
    row_indices : ndarray of int64
        Rows of ``content_full`` to select.
    col_indices : ndarray of int64
        Columns of ``content_full`` to select.
    K : ndarray
        Destination matrix, modified in place.
    row_start : int
        Row of ``K`` receiving the first selected row.
    col_start : int
        Column of ``K`` receiving the first selected column.

    Notes
    -----
    Writes ``K[row_start + i, col_start + j] =
    content_full[row_indices[i], col_indices[j]]``; nothing is returned.
    """
    n_sel_rows = len(row_indices)
    n_sel_cols = len(col_indices)
    for i in range(n_sel_rows):
        src_row = row_indices[i]
        dst_row = row_start + i
        for j in range(n_sel_cols):
            K[dst_row, col_start + j] = content_full[src_row, col_indices[j]]
@numba.jit(nopython=True, cache=True, parallel=False)
def extract_and_assign_transposed(content_full, row_indices, col_indices, K,
                                  row_start, col_start):
    """
    Copy the transpose of a selected block of ``content_full`` into ``K``.

    In-place equivalent of
    ``K[...] = content_full[np.ix_(row_indices, col_indices)].T``.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols_full)
        Matrix to read from.
    row_indices : ndarray of int64
        Rows of ``content_full`` to select.
    col_indices : ndarray of int64
        Columns of ``content_full`` to select.
    K : ndarray
        Destination matrix, modified in place.
    row_start : int
        First destination row in ``K``.
    col_start : int
        First destination column in ``K``.

    Notes
    -----
    Writes ``K[row_start + j, col_start + i] =
    content_full[row_indices[i], col_indices[j]]``; nothing is returned.
    """
    n_sel_rows = len(row_indices)
    n_sel_cols = len(col_indices)
    for i in range(n_sel_rows):
        src_row = row_indices[i]
        dst_col = col_start + i
        for j in range(n_sel_cols):
            # Source (row i, col j) lands at destination (row j, col i).
            K[row_start + j, dst_col] = content_full[src_row, col_indices[j]]
@numba.jit(nopython=True, cache=True)
def extract_rows_and_assign(content_full, row_indices, K,
                            row_start, col_start, n_cols):
    """
    Copy selected rows of ``content_full`` into ``K`` in place.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols)
        Matrix to read from.
    row_indices : ndarray of int64
        Rows of ``content_full`` to select.
    K : ndarray
        Destination matrix, modified in place.
    row_start : int
        Row of ``K`` receiving the first selected row.
    col_start : int
        Column of ``K`` receiving source column 0.
    n_cols : int
        Number of leading columns copied from every selected row.

    Notes
    -----
    Writes ``K[row_start + i, col_start + j] =
    content_full[row_indices[i], j]``; nothing is returned.
    """
    n_selected = len(row_indices)
    for i in range(n_selected):
        src_row = row_indices[i]
        dst_row = row_start + i
        for col in range(n_cols):
            K[dst_row, col_start + col] = content_full[src_row, col]
@numba.jit(nopython=True, cache=True)
def extract_cols_and_assign(content_full, col_indices, K,
                            row_start, col_start, n_rows):
    """
    Copy selected columns of ``content_full`` into ``K`` in place.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows, n_cols_full)
        Matrix to read from.
    col_indices : ndarray of int64
        Columns of ``content_full`` to select.
    K : ndarray
        Destination matrix, modified in place.
    row_start : int
        Row of ``K`` receiving source row 0.
    col_start : int
        Column of ``K`` receiving the first selected column.
    n_rows : int
        Number of leading rows copied from every selected column.

    Notes
    -----
    Writes ``K[row_start + i, col_start + j] =
    content_full[i, col_indices[j]]``; nothing is returned.
    """
    n_selected = len(col_indices)
    for row in range(n_rows):
        dst_row = row_start + row
        for j in range(n_selected):
            K[dst_row, col_start + j] = content_full[row, col_indices[j]]
@numba.jit(nopython=True, cache=True)
def extract_rows_and_assign_transposed(content_full, row_indices, K,
                                       row_start, col_start, n_cols):
    """
    Copy the transpose of selected rows of ``content_full`` into ``K``.

    In-place equivalent of ``K[...] = content_full[row_indices, :].T``.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols)
        Matrix to read from.
    row_indices : ndarray of int64
        Rows of ``content_full`` to select.
    K : ndarray
        Destination matrix, modified in place.
    row_start : int
        First destination row in ``K``.
    col_start : int
        First destination column in ``K``.
    n_cols : int
        Number of leading columns read from every selected row.

    Notes
    -----
    Writes ``K[row_start + j, col_start + i] =
    content_full[row_indices[i], j]``; nothing is returned.
    """
    n_selected = len(row_indices)
    for i in range(n_selected):
        src_row = row_indices[i]
        dst_col = col_start + i
        for col in range(n_cols):
            # Each selected source row becomes a destination column.
            K[row_start + col, dst_col] = content_full[src_row, col]
@numba.jit(nopython=True, cache=True)
def extract_cols_and_assign_transposed(content_full, col_indices, K,
                                       row_start, col_start, n_rows):
    """
    Copy the transpose of selected columns of ``content_full`` into ``K``.

    In-place equivalent of ``K[...] = content_full[:, col_indices].T``.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows, n_cols_full)
        Matrix to read from.
    col_indices : ndarray of int64
        Columns of ``content_full`` to select.
    K : ndarray
        Destination matrix, modified in place.
    row_start : int
        First destination row in ``K``.
    col_start : int
        First destination column in ``K``.
    n_rows : int
        Number of leading rows read from every selected column.

    Notes
    -----
    Writes ``K[row_start + j, col_start + i] =
    content_full[i, col_indices[j]]``; nothing is returned.
    """
    n_selected = len(col_indices)
    for row in range(n_rows):
        dst_col = col_start + row
        for j in range(n_selected):
            # Each selected source column becomes a destination row.
            K[row_start + j, dst_col] = content_full[row, col_indices[j]]
355# =============================================================================
356# Derivative index transformation utilities
357# =============================================================================
def make_first_odd(der_indices):
    """
    Remap the first entry of every index pair onto the odd sequence 1, 3, 5, ...

    Parameters
    ----------
    der_indices : iterable of iterables of pairs
        Each pair is indexable; element 0 is remapped as ``2 * first - 1``
        and element 1 is carried over unchanged.

    Returns
    -------
    list of list of list
        Same nesting as the input, with every pair emitted as a new
        two-element list ``[2 * pair[0] - 1, pair[1]]``.
    """
    return [
        [[2 * pair[0] - 1, pair[1]] for pair in group]
        for group in der_indices
    ]
def make_first_even(der_indices):
    """
    Remap the first entry of every index pair onto the even sequence 2, 4, 6, ...

    Parameters
    ----------
    der_indices : iterable of iterables of pairs
        Each pair is indexable; element 0 is remapped as ``2 * first``
        and element 1 is carried over unchanged.

    Returns
    -------
    list of list of list
        Same nesting as the input, with every pair emitted as a new
        two-element list ``[2 * pair[0], pair[1]]``.
    """
    return [
        [[2 * pair[0], pair[1]] for pair in group]
        for group in der_indices
    ]
383# =============================================================================
384# Difference computation functions
385# =============================================================================
def compute_dimension_differences(k, X1, X2, n1, n2, rays_X1, rays_X2,
                                  derivative_locations_X1, derivative_locations_X2,
                                  e_tags_1, e_tags_2, oti_module):
    """
    Build the OTI-tagged pairwise difference matrix for coordinate ``k``.

    Only the points listed in the derivative-location lists are perturbed,
    each by its own ray component scaled by the direction's OTI basis tag.

    Parameters
    ----------
    k : int
        Coordinate (dimension) index.
    X1, X2 : oti.array
        Point sets of shape (n1, d) and (n2, d).
    n1, n2 : int
        Number of points in ``X1`` and ``X2``.
    rays_X1, rays_X2 : list of ndarray or None
        ``rays[i]`` has shape (d, len(derivative_locations[i])); ``None``
        leaves the corresponding point set unperturbed.
    derivative_locations_X1, derivative_locations_X2 : list of list
        ``derivative_locations[i]`` lists the point indices carrying
        direction ``i``.
    e_tags_1, e_tags_2 : list
        OTI basis element per direction for ``X1`` and ``X2``.
    oti_module : module
        The PyOTI static module providing ``array``, ``zeros`` and
        ``transpose``.

    Returns
    -------
    diffs_k : oti.array
        Array of shape (n1, n2) holding tagged ``X1[:, k] - X2[:, k]``
        differences for every point pair.
    """
    def _perturbation_values(n_pts, rays_by_dir, locs_by_dir, tags):
        # One entry per point; points without a ray keep a plain 0.0.
        values = [0.0] * n_pts
        if rays_by_dir is None:
            return values
        for dir_idx in range(len(rays_by_dir)):
            point_ids = locs_by_dir[dir_idx]
            ray_block = rays_by_dir[dir_idx]
            for col, pt_idx in enumerate(point_ids):
                # Column `col` of the ray block belongs to point `pt_idx`.
                values[pt_idx] = values[pt_idx] + tags[dir_idx] * ray_block[k, col]
        return values

    values_1 = _perturbation_values(n1, rays_X1, derivative_locations_X1, e_tags_1)
    values_2 = _perturbation_values(n2, rays_X2, derivative_locations_X2, e_tags_2)

    # Promote the per-point perturbations to OTI arrays and tag coordinate k.
    tagged_1 = X1[:, k] + oti_module.array(values_1)
    tagged_2 = X2[:, k] + oti_module.array(values_2)

    # Row i holds tagged_1[i] minus every tagged X2 coordinate.
    diffs_k = oti_module.zeros((n1, n2))
    for row in range(n1):
        diffs_k[row, :] = tagged_1[row, 0] - oti_module.transpose(tagged_2[:, 0])

    return diffs_k
def differences_by_dim_func(X1, X2, rays_X1, rays_X2, derivative_locations_X1, derivative_locations_X2,
                            n_order, oti_module, return_deriv=True):
    """
    Compute per-dimension pairwise differences with dual OTI tagging.

    X1 directions are tagged with the odd bases (e_1, e_3, e_5, ...) and X2
    directions with the even bases (e_2, e_4, e_6, ...), so the scheme needs
    ``n_bases = 2 * n_direction_types``.

    Two separate tag families are used because every point may carry its
    own ray and the kernel needs derivatives with respect to both sets of
    directions at once. In ``X1 - X2`` the coefficient of basis e_i at
    position (a, b) then encodes only the ray of the point tagged with e_i.
    Sharing one basis between X1 and X2 would mix the two rays, so the
    cross-derivative ``v_i(a)^T H v_j(b)`` needed for the K_dd blocks could
    not be recovered and K_fd would become asymmetric for per-point rays.

    Parameters
    ----------
    X1 : ndarray of shape (n1, d)
        First set of input points.
    X2 : ndarray of shape (n2, d)
        Second set of input points.
    rays_X1, rays_X2 : list of ndarray or None
        ``rays[i]`` has shape (d, len(derivative_locations[i])).
    derivative_locations_X1, derivative_locations_X2 : list of list
        ``derivative_locations[i]`` lists the point indices that carry
        derivative direction ``i``.
    n_order : int
        Derivative order for OTI tagging; ``0`` disables tagging (all tags
        are the plain integer 0).
    oti_module : module
        The PyOTI static module (e.g., pyoti.static.onumm4n2).
    return_deriv : bool, optional
        If true, basis elements are created with order ``2 * n_order``
        (derivative-derivative blocks); otherwise with order ``n_order``.

    Returns
    -------
    differences_by_dim : list of oti.array
        One (n1, n2) OTI array per input dimension.
    """
    X1 = oti_module.array(X1)
    X2 = oti_module.array(X2)
    n1, d = X1.shape
    n2, _ = X2.shape

    # Both point sets get tags for the larger of the two direction counts.
    n_dirs_1 = 0 if rays_X1 is None else len(rays_X1)
    n_dirs_2 = 0 if rays_X2 is None else len(rays_X2)
    n_dirs = max(n_dirs_1, n_dirs_2)

    if n_order == 0:
        e_tags_1 = [0] * n_dirs
        e_tags_2 = [0] * n_dirs
    else:
        tag_order = 2 * n_order if return_deriv else n_order
        e_tags_1 = []
        e_tags_2 = []
        for dir_idx in range(n_dirs):
            # Odd basis for X1, the following even basis for X2.
            e_tags_1.append(oti_module.e((2 * dir_idx + 1), order=tag_order))
            e_tags_2.append(oti_module.e((2 * dir_idx + 2), order=tag_order))

    return [
        compute_dimension_differences(
            k, X1, X2, n1, n2, rays_X1, rays_X2,
            derivative_locations_X1, derivative_locations_X2,
            e_tags_1, e_tags_2, oti_module
        )
        for k in range(d)
    ]
537# =============================================================================
538# Derivative mapping utilities
539# =============================================================================
def deriv_map(nbases, order):
    """
    Build a lookup from (order, direction index) to a running flat index.

    Parameters
    ----------
    nbases : int
        Number of OTI bases, forwarded to ``coti.ndir_order``.
    order : int
        Highest order to include.

    Returns
    -------
    list of list of int
        ``result[o][idx]`` is the flat position of direction ``idx`` of
        order ``o``; positions are numbered consecutively, order by order,
        starting at 0.
    """
    next_flat = 0
    table = []
    for ordi in range(order + 1):
        n_directions = coti.ndir_order(nbases, ordi)
        table.append(list(range(next_flat, next_flat + n_directions)))
        next_flat += n_directions
    return table
def transform_der_indices(der_indices, der_map):
    """
    Translate derivative specifications into flat indices.

    Parameters
    ----------
    der_indices : iterable
        Derivative specifications, each passed to ``coti.imdir``.
    der_map : list of list of int
        Lookup indexed as ``der_map[order][idx]`` (see ``deriv_map``).

    Returns
    -------
    deriv_ind_transf : list
        Flat index of every derivative, in input order.
    deriv_ind_order : list
        The ``(idx, order)`` value returned by ``coti.imdir`` for every
        derivative, in input order.
    """
    flat_positions = []
    direction_pairs = []
    for spec in der_indices:
        pair = coti.imdir(spec)
        dir_idx, dir_order = pair
        flat_positions.append(der_map[dir_order][dir_idx])
        direction_pairs.append(pair)
    return flat_positions, direction_pairs
567# =============================================================================
568# RBF Kernel Assembly Functions (Optimized with Numba)
569# =============================================================================
@profile
def rbf_kernel(
    phi,
    phi_exp,
    n_order,
    n_bases,
    der_indices,
    index=None
):
    """
    Assemble the full GDDEGP covariance matrix with selective derivative coverage.

    Blocks are filled by the Numba slicing helpers instead of ``np.ix_``.
    Column (fd) blocks read the even-basis derivative planes, row (df)
    blocks read the odd-basis planes, and each dd block reads the plane
    obtained by multiplying the two directions.

    Parameters
    ----------
    phi : OTI array
        Base kernel matrix from kernel_func(differences, length_scales);
        only its ``shape`` is used here.
    phi_exp : ndarray
        Expanded derivative array from phi.get_all_derivs(), indexed by
        flat derivative index.
    n_order : int
        Maximum derivative order.
    n_bases : int
        Number of OTI bases (must be even).
    der_indices : list
        Derivative index specifications.
    index : list of list
        index[i] contains indices of points with derivative direction i.
        Despite the ``None`` default, a list must be supplied.

    Returns
    -------
    K : ndarray
        Kernel matrix laid out as
        ``[[K_ff, K_fd_1, ...], [K_df_1, K_dd_11, ...], ...]``.
    """
    dh = coti.get_dHelp()

    assert n_bases % 2 == 0, "n_bases must be an even number."
    n_func_rows, n_func_cols = phi.shape
    n_types = len(der_indices)

    # The same derivative locations extend both the rows and the columns.
    n_deriv_obs = sum(len(locs) for locs in index)
    n_output_rows = n_func_rows + n_deriv_obs
    n_output_cols = n_func_cols + n_deriv_obs

    der_map = deriv_map(n_bases, 2 * n_order)

    # Even-basis transforms serve the columns, odd-basis ones the rows.
    der_indices_even = make_first_even(der_indices)
    der_indices_odd = make_first_odd(der_indices)
    flat_even, imdir_even = transform_der_indices(der_indices_even, der_map)
    flat_odd, imdir_odd = transform_der_indices(der_indices_odd, der_map)

    # Numba helpers need int64 arrays rather than Python lists.
    index_arrays = [np.asarray(locs, dtype=np.int64) for locs in index]

    # row_starts[t] / col_starts[t]: first row / column of derivative type t.
    row_starts = [n_func_rows]
    col_starts = [n_func_cols]
    for t in range(n_types):
        block_size = len(index[t])
        row_starts.append(row_starts[-1] + block_size)
        col_starts.append(col_starts[-1] + block_size)

    K = np.zeros((n_output_rows, n_output_cols))

    # K_ff: function-function block.
    K[0:n_func_rows, 0:n_func_cols] = phi_exp[0]

    # K_fd: function rows against every derivative column block.
    for t in range(n_types):
        extract_cols_and_assign(phi_exp[flat_even[t]], index_arrays[t], K,
                                0, col_starts[t], n_func_rows)

    for r in range(n_types):
        row_locs = index_arrays[r]
        row_start = row_starts[r]

        # K_df: derivative rows against function columns.
        extract_rows_and_assign(phi_exp[flat_odd[r]], row_locs, K,
                                row_start, 0, n_func_cols)

        # K_dd: derivative rows against every derivative column block.
        imdir_row = imdir_odd[r]
        for c in range(n_types):
            imdir_col = imdir_even[c]
            new_idx, new_ord = dh.mult_dir(
                imdir_col[0], imdir_col[1], imdir_row[0], imdir_row[1])
            plane = der_map[new_ord][new_idx]

            extract_and_assign(phi_exp[plane], row_locs, index_arrays[c], K,
                               row_start, col_starts[c])

    return K
@numba.jit(nopython=True, cache=True)
def _assemble_kernel_numba(phi_exp_3d, K, n_rows_func, n_cols_func,
                           fd_flat_indices, df_flat_indices, dd_flat_indices,
                           idx_flat, idx_offsets, idx_sizes,
                           n_deriv_types, row_offsets, col_offsets):
    """
    Fused numba kernel for GDDEGP K matrix assembly (no signs, even/odd bases).

    Fills ``K`` in place from the derivative planes of ``phi_exp_3d``:

    * ff block: plane 0, copied for all function rows and columns;
    * fd blocks: plane ``fd_flat_indices[j]``, all function rows, the
      columns selected for derivative type ``j``;
    * df blocks: plane ``df_flat_indices[i]``, the rows selected for
      derivative type ``i``, all function columns;
    * dd blocks: plane ``dd_flat_indices[i, j]``, selected rows of type
      ``i`` against selected columns of type ``j``.

    The point indices of derivative type ``t`` are
    ``idx_flat[idx_offsets[t] : idx_offsets[t] + idx_sizes[t]]``;
    ``row_offsets[t]`` / ``col_offsets[t]`` give the first row / column of
    that type's block inside ``K``.
    """
    # ff block
    for row in range(n_rows_func):
        for col in range(n_cols_func):
            K[row, col] = phi_exp_3d[0, row, col]

    # fd blocks (even-basis planes)
    for t in range(n_deriv_types):
        plane = fd_flat_indices[t]
        k_col0 = col_offsets[t]
        first = idx_offsets[t]
        count = idx_sizes[t]
        for row in range(n_rows_func):
            for m in range(count):
                src_col = idx_flat[first + m]
                K[row, k_col0 + m] = phi_exp_3d[plane, row, src_col]

    # df blocks (odd-basis planes)
    for t in range(n_deriv_types):
        plane = df_flat_indices[t]
        k_row0 = row_offsets[t]
        first = idx_offsets[t]
        count = idx_sizes[t]
        for m in range(count):
            src_row = idx_flat[first + m]
            for col in range(n_cols_func):
                K[k_row0 + m, col] = phi_exp_3d[plane, src_row, col]

    # dd blocks (product of an even and an odd direction)
    for tr in range(n_deriv_types):
        k_row0 = row_offsets[tr]
        first_r = idx_offsets[tr]
        count_r = idx_sizes[tr]
        for tc in range(n_deriv_types):
            plane = dd_flat_indices[tr, tc]
            k_col0 = col_offsets[tc]
            first_c = idx_offsets[tc]
            count_c = idx_sizes[tc]
            for mr in range(count_r):
                src_row = idx_flat[first_r + mr]
                for mc in range(count_c):
                    src_col = idx_flat[first_c + mc]
                    K[k_row0 + mr, k_col0 + mc] = phi_exp_3d[plane, src_row, src_col]
@numba.jit(nopython=True, cache=True)
def _project_W_to_phi_space(W, W_proj, n_rows_func, n_cols_func,
                            fd_flat_indices, df_flat_indices, dd_flat_indices,
                            idx_flat, idx_offsets, idx_sizes,
                            n_deriv_types, row_offsets, col_offsets):
    """
    Reverse of _assemble_kernel_numba: project W from K-space back into
    phi_exp-space so that vdot(W, assemble(dphi_exp)) == vdot(W_proj, dphi_exp).
    No-signs variant for GDDEGP even/odd bases.

    ``W_proj`` is zeroed and then filled in place: every entry of ``W`` is
    added to the ``(plane, row, col)`` slot of ``W_proj`` that the assembly
    kernel reads for that position of ``K``. Accumulation (``+=``) is used
    because several positions of ``K`` may read the same slot.
    """
    # Start from a clean accumulator.
    W_proj[:, :, :] = 0.0

    # ff block -> plane 0
    for row in range(n_rows_func):
        for col in range(n_cols_func):
            W_proj[0, row, col] += W[row, col]

    # fd blocks
    for t in range(n_deriv_types):
        plane = fd_flat_indices[t]
        w_col0 = col_offsets[t]
        first = idx_offsets[t]
        count = idx_sizes[t]
        for row in range(n_rows_func):
            for m in range(count):
                dst_col = idx_flat[first + m]
                W_proj[plane, row, dst_col] += W[row, w_col0 + m]

    # df blocks
    for t in range(n_deriv_types):
        plane = df_flat_indices[t]
        w_row0 = row_offsets[t]
        first = idx_offsets[t]
        count = idx_sizes[t]
        for m in range(count):
            dst_row = idx_flat[first + m]
            for col in range(n_cols_func):
                W_proj[plane, dst_row, col] += W[w_row0 + m, col]

    # dd blocks
    for tr in range(n_deriv_types):
        w_row0 = row_offsets[tr]
        first_r = idx_offsets[tr]
        count_r = idx_sizes[tr]
        for tc in range(n_deriv_types):
            plane = dd_flat_indices[tr, tc]
            w_col0 = col_offsets[tc]
            first_c = idx_offsets[tc]
            count_c = idx_sizes[tc]
            for mr in range(count_r):
                dst_row = idx_flat[first_r + mr]
                for mc in range(count_c):
                    dst_col = idx_flat[first_c + mc]
                    W_proj[plane, dst_row, dst_col] += W[w_row0 + mr, w_col0 + mc]
def precompute_kernel_plan(n_order, n_bases, der_indices, powers, index):
    """
    Precompute structural info for rbf_kernel_fast (GDDEGP even/odd variant).

    Parameters
    ----------
    n_order : int
        Maximum derivative order; the derivative map is built for
        ``2 * n_order``.
    n_bases : int
        Number of OTI bases (must be even).
    der_indices : list
        Derivative index specifications.
    powers : object
        Not used by this function.
    index : list of list
        index[i] contains indices of points with derivative direction i.

    Returns
    -------
    dict
        Plan with the flattened point indices (``idx_flat``,
        ``idx_offsets``, ``index_sizes``), the block offsets relative to the
        end of the function block (``row_offsets``, ``col_offsets``), the
        flat derivative-plane indices for the fd / df / dd blocks, and
        bookkeeping entries (``index_arrays``, ``n_pts_with_derivs``,
        ``n_deriv_types``, ``signs``).
    """
    dh = coti.get_dHelp()
    assert n_bases % 2 == 0, "n_bases must be an even number."
    der_map = deriv_map(n_bases, 2 * n_order)

    n_deriv_types = len(der_indices)
    index_arrays = [np.asarray(idx, dtype=np.int64) for idx in index]

    index_sizes = np.array([len(idx) for idx in index_arrays], dtype=np.int64)
    n_pts_with_derivs = int(index_sizes.sum())

    if n_deriv_types > 0:
        idx_flat = np.concatenate(index_arrays)
    else:
        idx_flat = np.array([], dtype=np.int64)

    # Exclusive prefix sum of the block sizes. The same values locate a
    # type's slice in idx_flat and its block start in rows and columns;
    # three independent arrays are still returned.
    block_starts = np.zeros(n_deriv_types, dtype=np.int64)
    running = 0
    for t in range(n_deriv_types):
        block_starts[t] = running
        running += index_sizes[t]
    idx_offsets = block_starts
    row_offsets = block_starts.copy()
    col_offsets = block_starts.copy()

    # Even/odd derivative transforms
    der_indices_even = make_first_even(der_indices)
    der_indices_odd = make_first_odd(der_indices)
    flat_even, imdir_even = transform_der_indices(der_indices_even, der_map)
    flat_odd, imdir_odd = transform_der_indices(der_indices_odd, der_map)

    fd_flat_indices = np.array(flat_even, dtype=np.int64)
    df_flat_indices = np.array(flat_odd, dtype=np.int64)

    # dd plane for (row type r, column type c): product of the even
    # direction of c with the odd direction of r.
    dd_flat_indices = np.empty((n_deriv_types, n_deriv_types), dtype=np.int64)
    for r in range(n_deriv_types):
        imdir_row = imdir_odd[r]
        for c in range(n_deriv_types):
            imdir_col = imdir_even[c]
            new_idx, new_ord = dh.mult_dir(imdir_col[0], imdir_col[1], imdir_row[0], imdir_row[1])
            dd_flat_indices[r, c] = der_map[new_ord][new_idx]

    plan = {
        'signs': np.ones(n_deriv_types + 1, dtype=np.float64),  # unused, kept for API
        'index_arrays': index_arrays,
        'index_sizes': index_sizes,
        'n_pts_with_derivs': n_pts_with_derivs,
        'dd_flat_indices': dd_flat_indices,
        'n_deriv_types': n_deriv_types,
        'idx_flat': idx_flat,
        'idx_offsets': idx_offsets,
        'row_offsets': row_offsets,
        'col_offsets': col_offsets,
        'fd_flat_indices': fd_flat_indices,
        'df_flat_indices': df_flat_indices,
    }
    return plan
def rbf_kernel_fast(phi_exp_3d, plan, out=None):
    """
    Fast kernel assembly using precomputed plan and fused numba kernel.

    Parameters
    ----------
    phi_exp_3d : ndarray of shape (n_planes, n_rows_func, n_cols_func)
        Stacked derivative planes; plane 0 is the function-function block.
    plan : dict
        Plan produced by ``precompute_kernel_plan``. If it contains
        ``'row_offsets_abs'`` / ``'col_offsets_abs'`` they are used as the
        block starts; otherwise the relative ``'row_offsets'`` /
        ``'col_offsets'`` are shifted past the function block.
    out : ndarray, optional
        Preallocated destination. It is filled in place and returned; it
        must be at least
        ``(n_rows_func + n_pts, n_cols_func + n_pts)`` in size.

    Returns
    -------
    K : ndarray of shape (n_rows_func + n_pts, n_cols_func + n_pts)
        Assembled kernel matrix (``out`` itself when supplied), where
        ``n_pts = plan['n_pts_with_derivs']``.
    """
    n_rows_func = phi_exp_3d.shape[1]
    n_cols_func = phi_exp_3d.shape[2]
    n_deriv_pts = plan['n_pts_with_derivs']
    if out is not None:
        K = out
    else:
        # Rows and columns are sized independently: the assembler writes
        # columns up to n_cols_func + n_deriv_pts, so a square allocation
        # based on n_rows_func is only right when phi is square.
        K = np.empty((n_rows_func + n_deriv_pts, n_cols_func + n_deriv_pts))

    if 'row_offsets_abs' in plan:
        row_off = plan['row_offsets_abs']
        col_off = plan['col_offsets_abs']
    else:
        # Plan offsets are relative to the end of the function block.
        row_off = plan['row_offsets'] + n_rows_func
        col_off = plan['col_offsets'] + n_cols_func

    _assemble_kernel_numba(
        phi_exp_3d, K, n_rows_func, n_cols_func,
        plan['fd_flat_indices'], plan['df_flat_indices'], plan['dd_flat_indices'],
        plan['idx_flat'], plan['idx_offsets'], plan['index_sizes'],
        plan['n_deriv_types'], row_off, col_off,
    )
    return K
@profile
def rbf_kernel_predictions(
    phi,
    phi_exp,
    n_order,
    n_bases,
    der_indices,
    return_deriv,
    index=None,
    common_derivs=None,
    calc_cov=False,
):
    """
    Construct the RBF kernel matrix for predictions with selective derivative coverage.

    Blocks are filled by the Numba slicing helpers. Every derivative plane
    of ``phi_exp`` is reshaped to (n_train, n_test) and written transposed,
    so rows of the result correspond to test points.

    Parameters
    ----------
    phi : OTI array
        Base kernel matrix between test and training points; its shape is
        read as (n_train, n_test).
    phi_exp : ndarray
        Expanded derivative array from phi.get_all_derivs().
    n_order : int
        Maximum derivative order. For ``0`` the function returns
        ``phi.real.T``.
    n_bases : int
        Number of OTI bases.
    der_indices : list
        Derivative specifications for training data.
    return_deriv : bool
        If True, derivative rows are added for the test points.
    index : list of list
        Training point indices for each derivative type.
    common_derivs : list
        Derivative specifications to predict at the test points.
    calc_cov : bool
        If True, the matrix is being built for a covariance computation:
        the derivative columns then use ``common_derivs`` and all test
        locations instead of the training derivative layout.

    Returns
    -------
    K : ndarray
        Prediction kernel matrix, or ``phi.real.T`` for the
        function-values-only shortcuts.
    """
    # Covariance of function values only: no derivative blocks are needed.
    if calc_cov and not return_deriv:
        return phi.real.T

    dh = coti.get_dHelp()

    n_train, n_test = phi.shape
    n_deriv_types = len(der_indices)
    n_deriv_types_pred = len(common_derivs) if common_derivs else 0

    # Without derivative information the kernel is just the real part.
    if n_order == 0:
        return phi.real.T

    # Numba helpers need int64 arrays rather than Python lists.
    index_arrays = [np.asarray(locs, dtype=np.int64) for locs in index]

    # Derivative rows need the doubled order for the cross (dd) terms.
    if return_deriv:
        der_map = deriv_map(n_bases, 2 * n_order)
    else:
        der_map = deriv_map(n_bases, n_order)
    # Every predicted derivative type is evaluated at all test points.
    all_test_locs = np.arange(n_test, dtype=np.int64)

    der_indices_even = make_first_even(der_indices)
    der_indices_odd = make_first_odd(der_indices)
    flat_odd_train, imdir_odd_train = transform_der_indices(der_indices_odd, der_map)
    if common_derivs:
        flat_odd_pred, imdir_odd_pred = transform_der_indices(
            make_first_odd(common_derivs), der_map)
    else:
        flat_odd_pred, imdir_odd_pred = [], []

    # After the early return above, calc_cov implies return_deriv.
    cov_with_derivs = return_deriv and calc_cov

    # Rows: test function values, then one n_test block per predicted type.
    n_row_blocks = n_deriv_types_pred if return_deriv else 0
    total_rows = n_test + n_row_blocks * n_test
    row_starts = [n_test + blk * n_test for blk in range(n_row_blocks)]

    # Columns: training function values, then the derivative blocks.
    if cov_with_derivs:
        n_deriv_types = n_deriv_types_pred
        total_cols = n_train + n_deriv_types * n_test
        col_sizes = [n_test] * n_deriv_types
    else:
        total_cols = n_train + sum(len(locs) for locs in index)
        col_sizes = [len(index[j]) for j in range(n_deriv_types)]
    col_starts = []
    next_col = n_train
    for size in col_sizes:
        col_starts.append(next_col)
        next_col += size

    K = np.zeros((total_rows, total_cols))
    base_shape = (n_train, n_test)

    # Block (0,0): Function-Function (K_ff)
    K[:n_test, :n_train] = phi_exp[0].reshape(base_shape).T

    # First Block-Row: Function-Derivative (K_fd)
    for j in range(n_deriv_types):
        if cov_with_derivs:
            src_rows = all_test_locs
            plane = flat_odd_pred[j]
        else:
            src_rows = index_arrays[j]
            plane = flat_odd_train[j]
        block = phi_exp[plane].reshape(base_shape)
        extract_rows_and_assign_transposed(block, src_rows, K,
                                           0, col_starts[j], n_test)

    if not return_deriv:
        return K

    # Result unused below; the call is kept so invalid training indices
    # fail here exactly as before.
    transform_der_indices(der_indices_even, der_map)
    flat_even_pred, imdir_even_pred = transform_der_indices(
        make_first_even(common_derivs), der_map)

    # First Block-Column: Derivative-Function (K_df)
    for i in range(n_deriv_types_pred):
        block = phi_exp[flat_even_pred[i]].reshape(base_shape)
        extract_cols_and_assign_transposed(block, all_test_locs, K,
                                           row_starts[i], 0, n_train)

    # Inner Blocks: Derivative-Derivative (K_dd)
    for i in range(n_deriv_types_pred):
        imdir_test = imdir_even_pred[i]
        row_start = row_starts[i]

        for j in range(n_deriv_types):
            if cov_with_derivs:
                src_rows = all_test_locs
                imdir_train = imdir_odd_pred[j]
            else:
                src_rows = index_arrays[j]
                imdir_train = imdir_odd_train[j]

            new_idx, new_ord = dh.mult_dir(
                imdir_train[0], imdir_train[1],
                imdir_test[0], imdir_test[1]
            )
            block = phi_exp[der_map[new_ord][new_idx]].reshape(base_shape)

            # Selected training rows x test columns, written transposed.
            extract_and_assign_transposed(block, src_rows, all_test_locs, K,
                                          row_start, col_starts[j])

    return K