Coverage for jetgp/full_gddegp/wgddegp_utils.py: 68%
470 statements
« prev ^ index » next coverage.py v7.10.7, created at 2026-04-03 15:09 -0500
1import numpy as np
2from line_profiler import profile
3import pyoti.core as coti
4import numba
7# =============================================================================
8# Numba-accelerated helper functions for efficient matrix slicing
9# =============================================================================
@numba.jit(nopython=True, cache=True)
def extract_rows(content_full, row_indices, n_cols):
    """
    Gather the rows of ``content_full`` listed in ``row_indices``.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols)
        Matrix to copy from.
    row_indices : ndarray of int64
        Rows to gather, in output order.
    n_cols : int
        Width of the matrix.

    Returns
    -------
    result : ndarray of shape (len(row_indices), n_cols)
        The gathered rows.
    """
    out = np.empty((len(row_indices), n_cols))
    for out_row, src_row in enumerate(row_indices):
        for col in range(n_cols):
            out[out_row, col] = content_full[src_row, col]
    return out
@numba.jit(nopython=True, cache=True)
def extract_cols(content_full, col_indices, n_rows):
    """
    Gather the columns of ``content_full`` listed in ``col_indices``.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows, n_cols_full)
        Matrix to copy from.
    col_indices : ndarray of int64
        Columns to gather, in output order.
    n_rows : int
        Height of the matrix.

    Returns
    -------
    result : ndarray of shape (n_rows, len(col_indices))
        The gathered columns.
    """
    out = np.empty((n_rows, len(col_indices)))
    for row in range(n_rows):
        for out_col, src_col in enumerate(col_indices):
            out[row, out_col] = content_full[row, src_col]
    return out
@numba.jit(nopython=True, cache=True)
def extract_submatrix(content_full, row_indices, col_indices):
    """
    Gather the submatrix of ``content_full`` at the given rows and columns.
    Drop-in replacement for the expensive ``np.ix_`` fancy-index pattern.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols_full)
        Matrix to copy from.
    row_indices : ndarray of int64
        Rows to gather.
    col_indices : ndarray of int64
        Columns to gather.

    Returns
    -------
    result : ndarray of shape (len(row_indices), len(col_indices))
        The gathered submatrix.
    """
    out = np.empty((len(row_indices), len(col_indices)))
    for oi, src_row in enumerate(row_indices):
        for oj, src_col in enumerate(col_indices):
            out[oi, oj] = content_full[src_row, src_col]
    return out
@numba.jit(nopython=True, cache=True)
def extract_submatrix_transposed(content_full, row_indices, col_indices):
    """
    Gather a submatrix and return its transpose.
    Drop-in replacement for ``content_full[np.ix_(row_indices, col_indices)].T``.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols_full)
        Matrix to copy from.
    row_indices : ndarray of int64
        Rows to gather.
    col_indices : ndarray of int64
        Columns to gather.

    Returns
    -------
    result : ndarray of shape (len(col_indices), len(row_indices))
        The gathered submatrix, transposed.
    """
    out = np.empty((len(col_indices), len(row_indices)))
    for oi, src_row in enumerate(row_indices):
        for oj, src_col in enumerate(col_indices):
            # Swapped output indices produce the transpose directly.
            out[oj, oi] = content_full[src_row, src_col]
    return out
@numba.jit(nopython=True, cache=True)
def extract_rows_transposed(content_full, row_indices, n_cols):
    """
    Gather rows and return the transposed result.
    Drop-in replacement for ``content_full[row_indices, :].T``.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols)
        Matrix to copy from.
    row_indices : ndarray of int64
        Rows to gather.
    n_cols : int
        Width of the matrix.

    Returns
    -------
    result : ndarray of shape (n_cols, len(row_indices))
        The gathered rows, transposed.
    """
    out = np.empty((n_cols, len(row_indices)))
    for out_col, src_row in enumerate(row_indices):
        for col in range(n_cols):
            out[col, out_col] = content_full[src_row, col]
    return out
@numba.jit(nopython=True, cache=True)
def extract_cols_transposed(content_full, col_indices, n_rows):
    """
    Gather columns and return the transposed result.
    Drop-in replacement for ``content_full[:, col_indices].T``.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows, n_cols_full)
        Matrix to copy from.
    col_indices : ndarray of int64
        Columns to gather.
    n_rows : int
        Height of the matrix.

    Returns
    -------
    result : ndarray of shape (len(col_indices), n_rows)
        The gathered columns, transposed.
    """
    out = np.empty((len(col_indices), n_rows))
    for row in range(n_rows):
        for out_row, src_col in enumerate(col_indices):
            out[out_row, row] = content_full[row, src_col]
    return out
@numba.jit(nopython=True, cache=True, parallel=False)
def extract_and_assign(content_full, row_indices, col_indices, K,
                       row_start, col_start):
    """
    Gather a submatrix and write it directly into ``K`` in place.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols_full)
        Matrix to copy from.
    row_indices : ndarray of int64
        Rows to gather from ``content_full``.
    col_indices : ndarray of int64
        Columns to gather from ``content_full``.
    K : ndarray
        Destination matrix, modified in place.
    row_start : int
        First destination row in ``K``.
    col_start : int
        First destination column in ``K``.
    """
    for di, src_row in enumerate(row_indices):
        for dj, src_col in enumerate(col_indices):
            K[row_start + di, col_start + dj] = content_full[src_row, src_col]
@numba.jit(nopython=True, cache=True, parallel=False)
def extract_and_assign_transposed(content_full, row_indices, col_indices, K,
                                  row_start, col_start):
    """
    Gather a submatrix and write its transpose directly into ``K`` in place.
    Drop-in replacement for ``K[...] = content_full[np.ix_(row_indices, col_indices)].T``.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols_full)
        Matrix to copy from.
    row_indices : ndarray of int64
        Rows to gather from ``content_full``.
    col_indices : ndarray of int64
        Columns to gather from ``content_full``.
    K : ndarray
        Destination matrix, modified in place.
    row_start : int
        First destination row in ``K``.
    col_start : int
        First destination column in ``K``.
    """
    for di, src_row in enumerate(row_indices):
        for dj, src_col in enumerate(col_indices):
            # Transposed write: source (row, col) lands at destination (col, row).
            K[row_start + dj, col_start + di] = content_full[src_row, src_col]
@numba.jit(nopython=True, cache=True)
def extract_rows_and_assign(content_full, row_indices, K,
                            row_start, col_start, n_cols):
    """
    Gather rows and write them directly into ``K`` in place.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols)
        Matrix to copy from.
    row_indices : ndarray of int64
        Rows to gather.
    K : ndarray
        Destination matrix, modified in place.
    row_start : int
        First destination row in ``K``.
    col_start : int
        First destination column in ``K``.
    n_cols : int
        Number of columns to copy.
    """
    for di, src_row in enumerate(row_indices):
        for col in range(n_cols):
            K[row_start + di, col_start + col] = content_full[src_row, col]
@numba.jit(nopython=True, cache=True)
def extract_cols_and_assign(content_full, col_indices, K,
                            row_start, col_start, n_rows):
    """
    Gather columns and write them directly into ``K`` in place.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows, n_cols_full)
        Matrix to copy from.
    col_indices : ndarray of int64
        Columns to gather.
    K : ndarray
        Destination matrix, modified in place.
    row_start : int
        First destination row in ``K``.
    col_start : int
        First destination column in ``K``.
    n_rows : int
        Number of rows to copy.
    """
    for row in range(n_rows):
        for dj, src_col in enumerate(col_indices):
            K[row_start + row, col_start + dj] = content_full[row, src_col]
@numba.jit(nopython=True, cache=True)
def extract_rows_and_assign_transposed(content_full, row_indices, K,
                                       row_start, col_start, n_cols):
    """
    Gather rows and write the transposed block directly into ``K`` in place.
    Drop-in replacement for ``K[...] = content_full[row_indices, :].T``.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols)
        Matrix to copy from.
    row_indices : ndarray of int64
        Rows to gather.
    K : ndarray
        Destination matrix, modified in place.
    row_start : int
        First destination row in ``K``.
    col_start : int
        First destination column in ``K``.
    n_cols : int
        Number of columns in ``content_full``.
    """
    for di, src_row in enumerate(row_indices):
        for col in range(n_cols):
            # Transposed write: source columns become destination rows.
            K[row_start + col, col_start + di] = content_full[src_row, col]
@numba.jit(nopython=True, cache=True)
def extract_cols_and_assign_transposed(content_full, col_indices, K,
                                       row_start, col_start, n_rows):
    """
    Gather columns and write the transposed block directly into ``K`` in place.
    Drop-in replacement for ``K[...] = content_full[:, col_indices].T``.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows, n_cols_full)
        Matrix to copy from.
    col_indices : ndarray of int64
        Columns to gather.
    K : ndarray
        Destination matrix, modified in place.
    row_start : int
        First destination row in ``K``.
    col_start : int
        First destination column in ``K``.
    n_rows : int
        Number of rows in ``content_full``.
    """
    for row in range(n_rows):
        for dj, src_col in enumerate(col_indices):
            # Transposed write: source rows become destination columns.
            K[row_start + dj, col_start + row] = content_full[row, src_col]
354# =============================================================================
355# Difference computation functions
356# =============================================================================
def compute_dimension_differences(k, X1, X2, n1, n2, rays_X1, rays_X2,
                                  derivative_locations_X1, derivative_locations_X2,
                                  e_tags_1, e_tags_2, oti_module):
    """
    Compute pairwise differences along a single dimension k.

    Only the points listed in the derivative-location lists receive an OTI
    perturbation, each with its own ray direction.

    Parameters
    ----------
    k : int
        Dimension index.
    X1, X2 : oti.array
        Input point arrays of shape (n1, d) and (n2, d).
    n1, n2 : int
        Number of points in X1, X2.
    rays_X1 : list of ndarray or None
        rays_X1[i] has shape (d, len(derivative_locations_X1[i]));
        column j corresponds to point derivative_locations_X1[i][j].
    rays_X2 : list of ndarray or None
        rays_X2[i] has shape (d, len(derivative_locations_X2[i])).
    derivative_locations_X1 : list of list
        derivative_locations_X1[i] contains indices of X1 points with direction i.
    derivative_locations_X2 : list of list
        derivative_locations_X2[i] contains indices of X2 points with direction i.
    e_tags_1, e_tags_2 : list
        OTI basis elements for each direction.
    oti_module : module
        The PyOTI static module (e.g. pyoti.static.onumm4n2).

    Returns
    -------
    diffs_k : oti.array
        Differences for dimension k with shape (n1, n2).
    """
    def _perturbation(n_pts, rays_list, loc_list, e_tags):
        # Accumulate e_tag * ray component (for dimension k) at each tagged point.
        values = [0.0] * n_pts
        if rays_list is not None:
            for dir_idx, rays in enumerate(rays_list):
                for col, pt_idx in enumerate(loc_list[dir_idx]):
                    values[pt_idx] = values[pt_idx] + e_tags[dir_idx] * rays[k, col]
        return oti_module.array(values)

    # Tag the k-th coordinate of both point sets.
    X1_k_tagged = X1[:, k] + _perturbation(n1, rays_X1, derivative_locations_X1, e_tags_1)
    X2_k_tagged = X2[:, k] + _perturbation(n2, rays_X2, derivative_locations_X2, e_tags_2)

    # Row-by-row broadcasted difference: diffs_k[i, j] = X1_k[i] - X2_k[j].
    diffs_k = oti_module.zeros((n1, n2))
    for row in range(n1):
        diffs_k[row, :] = X1_k_tagged[row, 0] - oti_module.transpose(X2_k_tagged[:, 0])

    return diffs_k
def differences_by_dim_func(X1, X2, rays_X1, rays_X2, derivative_locations_X1, derivative_locations_X2,
                            n_order, oti_module, return_deriv=True):
    """
    Compute dimension-wise differences with OTI tagging on both X1 and X2.

    Only the points listed in the derivative-location lists receive an OTI
    perturbation, each with its own ray direction.

    Parameters
    ----------
    X1 : ndarray of shape (n1, d)
        First set of input points.
    X2 : ndarray of shape (n2, d)
        Second set of input points.
    rays_X1 : list of ndarray or None
        rays_X1[i] has shape (d, len(derivative_locations_X1[i]));
        column j holds the ray for point derivative_locations_X1[i][j].
    rays_X2 : list of ndarray or None
        rays_X2[i] has shape (d, len(derivative_locations_X2[i])).
    derivative_locations_X1 : list of list
        derivative_locations_X1[i] contains indices of X1 points with direction i.
    derivative_locations_X2 : list of list
        derivative_locations_X2[i] contains indices of X2 points with direction i.
    n_order : int
        Derivative order for OTI tagging.
    oti_module : module
        The PyOTI static module (e.g. pyoti.static.onumm4n2).
    return_deriv : bool, optional
        If True, tag with order 2*n_order (training kernel with
        derivative-derivative blocks); otherwise with order n_order.

    Returns
    -------
    differences_by_dim : list of oti.array
        List of length d; element k is the (n1, n2) OTI difference array
        for dimension k.
    """
    X1 = oti_module.array(X1)
    X2 = oti_module.array(X2)
    n1, d = X1.shape
    n2, _ = X2.shape

    # Number of derivative directions: the larger of the two ray lists.
    m = max(len(rays_X1) if rays_X1 is not None else 0,
            len(rays_X2) if rays_X2 is not None else 0)

    # OTI basis elements: odd bases tag X1, even bases tag X2.
    if n_order == 0:
        e_tags_1 = [0] * m
        e_tags_2 = [0] * m
    else:
        tag_order = 2 * n_order if return_deriv else n_order
        e_tags_1 = [oti_module.e(2 * i + 1, order=tag_order) for i in range(m)]
        e_tags_2 = [oti_module.e(2 * i + 2, order=tag_order) for i in range(m)]

    return [
        compute_dimension_differences(
            k, X1, X2, n1, n2, rays_X1, rays_X2,
            derivative_locations_X1, derivative_locations_X2,
            e_tags_1, e_tags_2, oti_module,
        )
        for k in range(d)
    ]
497# =============================================================================
498# Derivative index transformation utilities
499# =============================================================================
def make_first_odd(der_indices):
    """Map the first entry of every [base, order] pair onto the odd bases
    (1, 3, 5, ...) via base -> 2*base - 1; the second entry is kept as-is."""
    return [
        [[2 * pair[0] - 1, pair[1]] for pair in group]
        for group in der_indices
    ]
def make_first_even(der_indices):
    """Map the first entry of every [base, order] pair onto the even bases
    (2, 4, 6, ...) via base -> 2*base; the second entry is kept as-is."""
    return [
        [[2 * pair[0], pair[1]] for pair in group]
        for group in der_indices
    ]
526# =============================================================================
527# Derivative mapping utilities
528# =============================================================================
def deriv_map(nbases, order):
    """Create the mapping from (order, direction index) to a flattened index.

    Returns a list with one entry per order 0..order; entry ``ordi`` is a
    list assigning consecutive flat indices to the ``coti.ndir_order(nbases,
    ordi)`` directions of that order, continuing where the previous order
    stopped.
    """
    flat = 0
    map_deriv = []
    for ordi in range(order + 1):
        n_dirs = coti.ndir_order(nbases, ordi)
        # Consecutive flat indices for all directions of this order.
        map_deriv.append(list(range(flat, flat + n_dirs)))
        flat += n_dirs
    return map_deriv
def transform_der_indices(der_indices, der_map):
    """Translate derivative specifications into flattened indices.

    Returns
    -------
    deriv_ind_transf : list of int
        Flat index (via ``der_map``) for each derivative spec.
    deriv_ind_order : list
        The raw ``coti.imdir`` result (direction index, order) per spec.
    """
    deriv_ind_transf = []
    deriv_ind_order = []
    for deriv in der_indices:
        imdir = coti.imdir(deriv)
        direction, order = imdir
        deriv_ind_transf.append(der_map[order][direction])
        deriv_ind_order.append(imdir)
    return deriv_ind_transf, deriv_ind_order
556# =============================================================================
557# RBF Kernel Assembly Functions (Optimized with Numba)
558# =============================================================================
@profile
def rbf_kernel(
    phi,
    phi_exp,
    n_order,
    n_bases,
    der_indices,
    powers,
    index=-1
):
    """
    Assembles the full GDDEGP covariance matrix with support for selective
    derivative coverage via derivative locations.

    Uses the Numba-accelerated slicing helpers defined in this module in
    place of the expensive np.ix_ fancy-indexing operations.

    Parameters
    ----------
    phi : OTI array
        Base kernel matrix from kernel_func(differences, length_scales).
    phi_exp : ndarray
        Expanded derivative array. NOTE: this argument is overwritten below
        (from phi.real or phi.get_all_derivs) before it is ever read, so the
        value passed in is effectively ignored.
    n_order : int
        Maximum derivative order.
    n_bases : int
        Number of OTI bases (must be even). Also recomputed below from
        phi.get_active_bases() when n_order > 0.
    der_indices : list
        Derivative index specifications.
    powers : list of int
        Powers of (-1) applied to each term (unused but kept for API consistency).
    index : list of list
        index[i] contains indices of points with derivative direction i.
        NOTE(review): the -1 default is not iterable and would fail at the
        sum() below — callers appear to always pass a real list; confirm.

    Returns
    -------
    K : ndarray
        Kernel matrix with block structure based on derivative locations:
        one function block followed by one block per derivative direction.
    """
    dh = coti.get_dHelp()

    highest_order = n_order
    if n_order == 0:
        # No derivatives: the "expanded" array is just the real part with a
        # leading singleton axis so phi_exp[0] is the function-function block.
        n_bases = 0
        phi_exp = phi.real
        phi_exp = phi_exp[np.newaxis, :, :]
    else:
        # Expand all derivatives up to order 2*n_order (needed for the
        # derivative-derivative cross blocks).
        n_bases = phi.get_active_bases()[-1]
        phi_exp = phi.get_all_derivs(n_bases, 2 * highest_order)
        assert n_bases % 2 == 0, "n_bases must be an even number."
    PHIrows, PHIcols = phi.shape
    total_derivs = len(der_indices)

    # Output size: function block plus one row/column per tagged point of
    # each derivative direction.
    n_deriv_rows = sum(len(locs) for locs in index)
    n_deriv_cols = sum(len(locs) for locs in index)
    n_output_rows = PHIrows + n_deriv_rows
    n_output_cols = PHIcols + n_deriv_cols

    der_map = deriv_map(n_bases, 2 * highest_order)

    # One extra iteration for the function block in each direction.
    row_iters = total_derivs + 1
    col_iters = total_derivs + 1

    # Pre-compute derivative index transformations: even bases address the
    # column (fd) derivatives, odd bases the row (df) derivatives.
    der_indices_even = make_first_even(der_indices)
    der_indices_odd = make_first_odd(der_indices)
    der_indices_tr_even, der_ind_order_even = transform_der_indices(der_indices_even, der_map)
    der_indices_tr_odd, der_ind_order_odd = transform_der_indices(der_indices_odd, der_map)

    # Convert index lists to numpy arrays for numba.
    index_arrays = [np.asarray(locs, dtype=np.int64) for locs in index]

    # Block offsets: entry 1 is the end of the function block; each later
    # entry adds the point count of one derivative direction.
    row_offsets = [0, PHIrows]
    for i in range(total_derivs):
        row_offsets.append(row_offsets[-1] + len(index[i]))

    col_offsets = [0, PHIcols]
    for i in range(total_derivs):
        col_offsets.append(col_offsets[-1] + len(index[i]))

    # Allocate output matrix.
    K = np.zeros((n_output_rows, n_output_cols))

    # Fill blocks.
    for i in range(row_iters):
        for j in range(col_iters):

            if i == 0 and j == 0:
                # K_ff: full function-function block (all points).
                idx = 0
                K[0:PHIrows, 0:PHIcols] = phi_exp[idx]

            elif i == 0 and j > 0:
                # K_fd: function rows (all), derivative-j columns restricted
                # to the tagged points index[j-1].
                idx = der_indices_tr_even[j - 1]
                col_locs = index_arrays[j - 1]
                col_start = col_offsets[j]

                # Use numba for efficient column extraction.
                extract_cols_and_assign(phi_exp[idx], col_locs, K,
                                        0, col_start, PHIrows)

            elif i > 0 and j == 0:
                # K_df: derivative-i rows restricted to index[i-1], function
                # columns (all).
                idx = der_indices_tr_odd[i - 1]
                row_locs = index_arrays[i - 1]
                row_start = row_offsets[i]

                # Use numba for efficient row extraction.
                extract_rows_and_assign(phi_exp[idx], row_locs, K,
                                        row_start, 0, PHIcols)

            else:
                # K_dd: derivative-i rows, derivative-j columns. The flat
                # index of the cross derivative is obtained by multiplying
                # the even (column) and odd (row) OTI directions.
                imdir1 = der_ind_order_even[j - 1]
                imdir2 = der_ind_order_odd[i - 1]
                new_idx, new_ord = dh.mult_dir(
                    imdir1[0], imdir1[1], imdir2[0], imdir2[1])
                idx = der_map[new_ord][new_idx]

                row_locs = index_arrays[i - 1]
                col_locs = index_arrays[j - 1]
                row_start = row_offsets[i]
                col_start = col_offsets[j]

                # Use numba for efficient submatrix extraction (replaces np.ix_).
                extract_and_assign(phi_exp[idx], row_locs, col_locs, K,
                                   row_start, col_start)

    return K
@numba.jit(nopython=True, cache=True)
def _assemble_kernel_numba(phi_exp_3d, K, n_rows_func, n_cols_func,
                           fd_flat_indices, df_flat_indices, dd_flat_indices,
                           idx_flat, idx_offsets, idx_sizes,
                           n_deriv_types, row_offsets, col_offsets):
    """Fused numba kernel for GDDEGP K matrix assembly (no signs, even/odd bases).

    Writes the four block families of K in one compiled pass:
    ff (function-function), fd (function rows x derivative columns),
    df (derivative rows x function columns), and dd (derivative x derivative).

    Parameters
    ----------
    phi_exp_3d : ndarray (n_derivs, n_rows_full, n_cols_full)
        Stacked derivative slices; slice 0 is the function-function kernel.
    K : ndarray
        Destination matrix, filled in place.
    n_rows_func, n_cols_func : int
        Size of the function block (top-left corner of K).
    fd_flat_indices, df_flat_indices : ndarray of int64
        Per-derivative-type slice index into phi_exp_3d for the fd / df blocks.
    dd_flat_indices : ndarray of int64, shape (n_deriv_types, n_deriv_types)
        Slice index for each (row-derivative, column-derivative) pair.
    idx_flat : ndarray of int64
        All tagged-point index lists concatenated.
    idx_offsets, idx_sizes : ndarray of int64
        Start and length of each derivative type's segment within idx_flat.
    n_deriv_types : int
        Number of derivative directions.
    row_offsets, col_offsets : ndarray of int64
        Absolute start row/column in K of each derivative type's block
        (already shifted past the function block by the caller).
    """
    # ff block
    for r in range(n_rows_func):
        for c in range(n_cols_func):
            K[r, c] = phi_exp_3d[0, r, c]
    # fd block (even indices)
    for j in range(n_deriv_types):
        fi = fd_flat_indices[j]
        co = col_offsets[j]
        off_j = idx_offsets[j]
        sz_j = idx_sizes[j]
        for r in range(n_rows_func):
            for k in range(sz_j):
                # ci is the original point index of the k-th tagged point.
                ci = idx_flat[off_j + k]
                K[r, co + k] = phi_exp_3d[fi, r, ci]
    # df block (odd indices)
    for i in range(n_deriv_types):
        fi = df_flat_indices[i]
        ro = row_offsets[i]
        off_i = idx_offsets[i]
        sz_i = idx_sizes[i]
        for k in range(sz_i):
            ri = idx_flat[off_i + k]
            for c in range(n_cols_func):
                K[ro + k, c] = phi_exp_3d[fi, ri, c]
    # dd block (even × odd)
    for i in range(n_deriv_types):
        ro = row_offsets[i]
        off_i = idx_offsets[i]
        sz_i = idx_sizes[i]
        for j in range(n_deriv_types):
            fi = dd_flat_indices[i, j]
            co = col_offsets[j]
            off_j = idx_offsets[j]
            sz_j = idx_sizes[j]
            for ki in range(sz_i):
                ri = idx_flat[off_i + ki]
                for kj in range(sz_j):
                    ci = idx_flat[off_j + kj]
                    K[ro + ki, co + kj] = phi_exp_3d[fi, ri, ci]
@numba.jit(nopython=True, cache=True)
def _project_W_to_phi_space(W, W_proj, n_rows_func, n_cols_func,
                            fd_flat_indices, df_flat_indices, dd_flat_indices,
                            idx_flat, idx_offsets, idx_sizes,
                            n_deriv_types, row_offsets, col_offsets):
    """
    Reverse of _assemble_kernel_numba: project W from K-space back into
    phi_exp-space so that vdot(W, assemble(dphi_exp)) == vdot(W_proj, dphi_exp).
    No-signs variant for WGDDEGP even/odd bases.

    W has the assembled-K shape; W_proj has the phi_exp shape
    (n_derivs, n_rows_full, n_cols_full) and is overwritten in place.
    Each entry of W is scatter-ADDED to the phi_exp slot it was gathered
    from during assembly (accumulation matters: different K blocks can map
    to the same phi_exp slot). Index parameters are as in
    _assemble_kernel_numba.
    """
    # Zero the accumulator before scatter-adding.
    for d in range(W_proj.shape[0]):
        for r in range(W_proj.shape[1]):
            for c in range(W_proj.shape[2]):
                W_proj[d, r, c] = 0.0
    # ff block -> derivative slice 0.
    for r in range(n_rows_func):
        for c in range(n_cols_func):
            W_proj[0, r, c] += W[r, c]
    # fd block (even-base derivative columns).
    for j in range(n_deriv_types):
        fi = fd_flat_indices[j]
        co = col_offsets[j]
        off_j = idx_offsets[j]
        sz_j = idx_sizes[j]
        for r in range(n_rows_func):
            for k in range(sz_j):
                ci = idx_flat[off_j + k]
                W_proj[fi, r, ci] += W[r, co + k]
    # df block (odd-base derivative rows).
    for i in range(n_deriv_types):
        fi = df_flat_indices[i]
        ro = row_offsets[i]
        off_i = idx_offsets[i]
        sz_i = idx_sizes[i]
        for k in range(sz_i):
            ri = idx_flat[off_i + k]
            for c in range(n_cols_func):
                W_proj[fi, ri, c] += W[ro + k, c]
    # dd block (even × odd cross derivatives).
    for i in range(n_deriv_types):
        ro = row_offsets[i]
        off_i = idx_offsets[i]
        sz_i = idx_sizes[i]
        for j in range(n_deriv_types):
            fi = dd_flat_indices[i, j]
            co = col_offsets[j]
            off_j = idx_offsets[j]
            sz_j = idx_sizes[j]
            for ki in range(sz_i):
                ri = idx_flat[off_i + ki]
                for kj in range(sz_j):
                    ci = idx_flat[off_j + kj]
                    W_proj[fi, ri, ci] += W[ro + ki, co + kj]
def precompute_kernel_plan(n_order, n_bases, der_indices, powers, index):
    """Precompute the structural data needed by rbf_kernel_fast
    (GDDEGP even/odd variant).

    Parameters
    ----------
    n_order : int
        Maximum derivative order.
    n_bases : int
        Number of OTI bases (must be even).
    der_indices : list
        Derivative index specifications.
    powers : list of int
        Unused; kept for API consistency.
    index : list of list
        index[i] contains the point indices with derivative direction i.

    Returns
    -------
    dict
        Plan with flattened index arrays, per-type offsets/sizes, and the
        flat phi_exp slice indices for the fd/df/dd blocks.
    """
    dh = coti.get_dHelp()
    assert n_bases % 2 == 0, "n_bases must be an even number."
    der_map = deriv_map(n_bases, 2 * n_order)

    n_deriv_types = len(der_indices)
    index_arrays = [np.asarray(locs, dtype=np.int64) for locs in index]
    index_sizes = np.array([arr.shape[0] for arr in index_arrays], dtype=np.int64)
    n_pts_with_derivs = int(index_sizes.sum())

    # All tagged-point lists concatenated, with per-type start offsets.
    if n_deriv_types > 0:
        idx_flat = np.concatenate(index_arrays)
    else:
        idx_flat = np.array([], dtype=np.int64)
    idx_offsets = np.zeros(n_deriv_types, dtype=np.int64)
    if n_deriv_types > 1:
        idx_offsets[1:] = np.cumsum(index_sizes[:-1])

    # Block starts within the derivative part of K (relative; the caller
    # shifts them past the function block).
    row_offsets = np.zeros(n_deriv_types, dtype=np.int64)
    if n_deriv_types > 1:
        row_offsets[1:] = np.cumsum(index_sizes[:-1])
    col_offsets = row_offsets.copy()

    # Even/odd derivative transforms: even bases for fd, odd bases for df.
    der_indices_tr_even, der_ind_order_even = transform_der_indices(
        make_first_even(der_indices), der_map)
    der_indices_tr_odd, der_ind_order_odd = transform_der_indices(
        make_first_odd(der_indices), der_map)

    fd_flat_indices = np.array(der_indices_tr_even, dtype=np.int64)
    df_flat_indices = np.array(der_indices_tr_odd, dtype=np.int64)

    # Cross-derivative slice index for every (row, column) derivative pair.
    dd_flat_indices = np.empty((n_deriv_types, n_deriv_types), dtype=np.int64)
    for row in range(n_deriv_types):
        imdir_odd = der_ind_order_odd[row]
        for col in range(n_deriv_types):
            imdir_even = der_ind_order_even[col]
            new_idx, new_ord = dh.mult_dir(imdir_even[0], imdir_even[1],
                                           imdir_odd[0], imdir_odd[1])
            dd_flat_indices[row, col] = der_map[new_ord][new_idx]

    return {
        'signs': np.ones(n_deriv_types + 1, dtype=np.float64),  # unused, kept for API
        'index_arrays': index_arrays,
        'index_sizes': index_sizes,
        'n_pts_with_derivs': n_pts_with_derivs,
        'dd_flat_indices': dd_flat_indices,
        'n_deriv_types': n_deriv_types,
        'idx_flat': idx_flat,
        'idx_offsets': idx_offsets,
        'row_offsets': row_offsets,
        'col_offsets': col_offsets,
        'fd_flat_indices': fd_flat_indices,
        'df_flat_indices': df_flat_indices,
    }
def rbf_kernel_fast(phi_exp_3d, plan, out=None):
    """Assemble the kernel matrix from a precomputed plan via the fused
    numba kernel.

    Parameters
    ----------
    phi_exp_3d : ndarray (n_derivs, n_rows, n_cols)
        Stacked derivative slices of the base kernel.
    plan : dict
        Output of precompute_kernel_plan.
    out : ndarray, optional
        Preallocated destination; a fresh matrix is created when omitted.

    Returns
    -------
    K : ndarray
        Assembled kernel matrix.
    """
    n_rows_func = phi_exp_3d.shape[1]
    n_cols_func = phi_exp_3d.shape[2]
    total = n_rows_func + plan['n_pts_with_derivs']
    K = out if out is not None else np.empty((total, total))

    # Derivative-block offsets are absolute if the plan provides them,
    # otherwise shifted past the function block here.
    if 'row_offsets_abs' in plan:
        row_off = plan['row_offsets_abs']
        col_off = plan['col_offsets_abs']
    else:
        row_off = plan['row_offsets'] + n_rows_func
        col_off = plan['col_offsets'] + n_cols_func

    _assemble_kernel_numba(
        phi_exp_3d, K, n_rows_func, n_cols_func,
        plan['fd_flat_indices'], plan['df_flat_indices'], plan['dd_flat_indices'],
        plan['idx_flat'], plan['idx_offsets'], plan['index_sizes'],
        plan['n_deriv_types'], row_off, col_off,
    )
    return K
def rbf_kernel_predictions(
    phi,
    phi_exp,
    n_order,
    n_bases,
    der_indices,
    powers,
    return_deriv,
    index=-1,
    common_derivs=None,
    calc_cov=False,
    powers_predict=None
):
    """
    Constructs the RBF kernel matrix for predictions with selective derivative coverage.

    Uses the Numba-accelerated slicing helpers defined in this module in
    place of the expensive np.ix_ fancy-indexing operations.

    Parameters
    ----------
    phi : OTI array
        Base kernel matrix between test and training points, shape
        (n_train, n_test).
    phi_exp : ndarray
        Expanded derivative array. NOTE: this argument is overwritten below
        (from phi.get_all_derivs) before it is ever read when n_order > 0.
    n_order : int
        Maximum derivative order.
    n_bases : int
        Number of OTI bases.
    der_indices : list
        Derivative specifications for training data.
    powers : list of int
        Sign powers (unused but kept for API consistency).
    return_deriv : bool
        If True, predict derivatives at the test points as well.
    index : list of list
        Training point indices for each derivative type. NOTE(review): the
        -1 default is not iterable; callers appear to always pass a list.
    common_derivs : list
        Common derivative indices to predict.
    calc_cov : bool
        If True, computing the test-test covariance (both sides use the
        prediction derivatives).
    powers_predict : list of int, optional
        Sign powers for prediction derivatives (unused but kept for API consistency).

    Returns
    -------
    K : ndarray
        Prediction kernel matrix (test rows by training/test columns).
    """
    # Plain covariance without derivative outputs needs no assembly.
    if calc_cov and not return_deriv:
        return phi.real

    dh = coti.get_dHelp()

    n_train, n_test = phi.shape
    n_deriv_types = len(der_indices)
    n_deriv_types_pred = len(common_derivs) if common_derivs else 0

    # Handle n_order = 0 case: no derivative blocks at all.
    if n_order == 0:
        return phi.real.T

    # Convert index lists to numpy arrays for numba.
    index_arrays = [np.asarray(locs, dtype=np.int64) for locs in index]

    # Extract derivative components: order n_order suffices when no test
    # derivatives are produced; cross blocks need 2*n_order.
    if not return_deriv:
        phi_exp = phi.get_all_derivs(n_bases, n_order)
        der_map = deriv_map(n_bases, n_order)
    else:
        phi_exp = phi.get_all_derivs(n_bases, 2 * n_order)
        der_map = deriv_map(n_bases, 2 * n_order)

    # Create derivative index transformations (odd bases = training side).
    der_indices_even = make_first_even(der_indices)
    der_indices_odd = make_first_odd(der_indices)
    der_indices_tr_odd, der_ind_order_odd = transform_der_indices(der_indices_odd, der_map)
    der_indices_odd_pred = make_first_odd(common_derivs) if common_derivs else []
    der_indices_tr_odd_pred, der_ind_order_odd_pred = transform_der_indices(der_indices_odd_pred, der_map) if common_derivs else ([], [])

    # Compute matrix dimensions. Every test point gets every predicted
    # derivative direction, hence the full arange per direction.
    n_rows_func = n_test
    if return_deriv:
        derivative_locations_test = [np.arange(n_test, dtype=np.int64)] * n_deriv_types_pred
        n_rows_derivs = sum(len(locs) for locs in derivative_locations_test)
    else:
        n_rows_derivs = 0
    total_rows = n_rows_func + n_rows_derivs

    if return_deriv and calc_cov:
        # Test-test covariance: columns also use the prediction derivatives.
        # NOTE(review): n_deriv_types is re-bound to the prediction count
        # here, yet the col_offsets/K_fd loops below still read the training
        # `index` lists — confirm callers guarantee these are consistent.
        n_cols_func = n_train
        n_deriv_types = n_deriv_types_pred
        n_cols_derivs = sum(len(locs) for locs in derivative_locations_test)
        total_cols = n_cols_func + n_cols_derivs
    else:
        n_cols_func = n_train
        n_cols_derivs = sum(len(locs) for locs in index)
        total_cols = n_cols_func + n_cols_derivs

    # Compute block offsets (entry 1 marks the end of the function block).
    row_offsets = [0, n_test]
    if return_deriv:
        for i in range(n_deriv_types_pred):
            row_offsets.append(row_offsets[-1] + len(derivative_locations_test[i]))

    col_offsets = [0, n_train]
    for i in range(n_deriv_types):
        col_offsets.append(col_offsets[-1] + len(index[i]))

    # Allocate output matrix. phi slices arrive flattened; base_shape
    # restores the (n_train, n_test) layout before transposing into K.
    K = np.zeros((total_rows, total_cols))
    base_shape = (n_train, n_test)

    # Block (0,0): Function-Function (K_ff)
    content_full = phi_exp[0].reshape(base_shape)
    K[:n_test, :n_train] = content_full.T

    # First Block-Row: Function-Derivative (K_fd)
    for j in range(n_deriv_types):
        train_locs = index_arrays[j]
        col_start = col_offsets[j + 1]

        flat_idx = der_indices_tr_odd_pred[j] if calc_cov else der_indices_tr_odd[j]
        content_full = phi_exp[flat_idx].reshape(base_shape)

        # Use numba for efficient row extraction with transpose
        extract_rows_and_assign_transposed(content_full, train_locs, K,
                                           0, col_start, n_test)

    if not return_deriv:
        return K

    # First Block-Column: Derivative-Function (K_df) — even bases address
    # the test-side derivatives.
    der_indices_tr_even, der_ind_order_even = transform_der_indices(der_indices_even, der_map)
    der_indices_even_pred = make_first_even(common_derivs)
    der_indices_tr_even_pred, der_ind_order_even_pred = transform_der_indices(der_indices_even_pred, der_map)

    for i in range(n_deriv_types_pred):
        test_locs = derivative_locations_test[i]
        row_start = row_offsets[i + 1]

        flat_idx = der_indices_tr_even_pred[i]
        content_full = phi_exp[flat_idx].reshape(base_shape)

        # Use numba for efficient column extraction with transpose
        extract_cols_and_assign_transposed(content_full, test_locs, K,
                                           row_start, 0, n_train)

    # Inner Blocks: Derivative-Derivative (K_dd) — cross direction obtained
    # by multiplying the training-side (odd) and test-side (even) directions.
    for i in range(n_deriv_types_pred):
        test_locs = derivative_locations_test[i]
        row_start = row_offsets[i + 1]

        for j in range(n_deriv_types):
            train_locs = index_arrays[j]
            col_start = col_offsets[j + 1]

            imdir_train = der_ind_order_odd_pred[j] if calc_cov else der_ind_order_odd[j]
            imdir_test = der_ind_order_even_pred[i]
            new_idx, new_ord = dh.mult_dir(
                imdir_train[0], imdir_train[1],
                imdir_test[0], imdir_test[1]
            )
            flat_idx = der_map[new_ord][new_idx]

            content_full = phi_exp[flat_idx].reshape(base_shape)

            # Use numba for efficient submatrix extraction with transpose (replaces np.ix_ + .T)
            extract_and_assign_transposed(content_full, train_locs, test_locs, K,
                                          row_start, col_start)

    return K
1050# =============================================================================
1051# Utility functions
1052# =============================================================================
def determine_weights(diffs_by_dim, diffs_test, length_scales, kernel_func, sigma_n):
    """
    Compute interpolation weights for many test points in one solve.

    Solves the constrained interpolation system

        [K  1] [w]   [r]
        [1  0] [λ] = [1]

    where the last row/column enforces that the weights of each test point
    sum to one (λ is a Lagrange multiplier, discarded on return).

    Parameters
    ----------
    diffs_by_dim : list of ndarray
        Pairwise differences between training points (by dimension).
    diffs_test : list of ndarray
        Pairwise differences between test and training points (by dimension);
        the kernel evaluated on them yields an (n_test, n_train) matrix.
    length_scales : array-like
        Kernel hyperparameters.
    kernel_func : callable
        Kernel function; its result must expose ``.real``.
    sigma_n : float
        Noise parameter (unused here; kept for API consistency).

    Returns
    -------
    weights_matrix : ndarray of shape (n_test, n_train)
        Interpolation weights for each test point.
    """
    # Training covariance — shared by every test point.
    K = kernel_func(diffs_by_dim, length_scales).real
    n_train = K.shape[0]

    # Test-train covariances for all test points at once.
    r_all = kernel_func(diffs_test, length_scales).real
    n_test = r_all.shape[0]

    # Augmented system matrix (bottom-right entry stays zero).
    M = np.zeros((n_train + 1, n_train + 1))
    M[:n_train, :n_train] = K
    M[:n_train, n_train] = 1
    M[n_train, :n_train] = 1

    # Augmented right-hand sides, one column per test point after transpose.
    rhs = np.zeros((n_test, n_train + 1))
    rhs[:, :n_train] = r_all
    rhs[:, n_train] = 1

    solution = np.linalg.solve(M, rhs.T)

    # Drop the Lagrange-multiplier row and restore (n_test, n_train) layout.
    return solution[:n_train, :].T
def to_list(x):
    """Recursively convert a (possibly nested) tuple into nested lists;
    non-tuples are returned unchanged."""
    return [to_list(item) for item in x] if isinstance(x, tuple) else x
def to_tuple(item):
    """Recursively convert a (possibly nested) list into nested tuples;
    non-lists are returned unchanged."""
    if not isinstance(item, list):
        return item
    return tuple(to_tuple(elem) for elem in item)
def find_common_derivatives(all_indices):
    """Find derivative indices common to all submodels.

    Parameters
    ----------
    all_indices : list of list
        all_indices[s] lists the derivative-index specs of submodel s;
        each spec may be a nested list and is converted to a hashable
        tuple before comparison.

    Returns
    -------
    set of tuple
        The specs present in every submodel; empty set when all_indices
        is empty.
    """
    # Guard: the original sets[0] access raised IndexError on empty input;
    # with no submodels there are no common derivatives.
    if not all_indices:
        return set()
    sets = [set(to_tuple(elem) for elem in idx_list) for idx_list in all_indices]
    return sets[0].intersection(*sets[1:])