Hello all,
I have a problem with my Python code: it takes forever to run (more than 2 weeks of running so far). The script takes a particle-tracking NetCDF file for a single year and a shapefile of coral reef polygons, then builds a “connectivity” (transition) matrix that counts how many particles travel from each source polygon to each sink polygon each day. It also applies a probabilistic “settlement competency” filter based on a Randall survival curve read from a CSV file.
The bottleneck is looping over the more than a million particles at each timestep and checking whether each particle’s position is contained in any of my 207 reef polygons as defined by my shapefile. Any recommendations that could significantly speed up my code would be super helpful. I have included the full code at the end of this post and also as a .py file attachment if that helps. Thank you very much :)
This is the portion of the code which slows it down a lot:
# Initialize the 3D numpy array to hold the transition matrices for each day
all_daily_matrices = []

# Loop over each timestep -- section of the code which takes the longest
for t in range(int(Tmin), num_days * 24):
    if t >= num_timesteps:
        break
    if t % 24 == 0:
        print(f"Processing day {t // 24}", flush=True)

    # Initialize a new transitions matrix for this timestep
    transitions_matrix = np.zeros((num_polygons, num_polygons), dtype=int)

    # Loop over each particle
    for i in range(num_particles):
        point_current = Point(lon[i, t], lat[i, t])
        if np.isnan(lon[i, t]) or np.isnan(lat[i, t]):
            continue  # skip particles that are beached or out of domain (NaN coordinates)
        probability = P_t_interpolator(t / 24)
        # Check probability of settlement
        if np.random.rand() <= probability:
            for k in idx.intersection(point_current.bounds):
                if point_current.within(gdf_all.geometry[k]):
                    transitions_matrix[initial_polygons[i], k] += 1
                    break  # exit once the sink polygon is found

Brief description of what the code does:
This script begins by loading three inputs (a condensed loading sketch follows the list):
1. A NetCDF file containing hourly longitude and latitude for over a million particles across 30 days (with -999 marking beached or out-of-domain points).
2. A shapefile containing 207 coral reef polygons, some of which are multipart and quite complex. These polygons form the rows and columns of the output connectivity (transition) matrix.
3. A CSV of a Randall competency curve (probability of settlement as a function of larval age in days).
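For reference, that loading step condenses to the following sketch (paths are shortened here; the exact paths and column names are in the full code at the end of the post):

import pandas as pd
import xarray as xr
import geopandas as gpd
from scipy.interpolate import interp1d

# 1. Hourly particle positions; -999 marks beached / out-of-domain points
data_xarray = xr.open_dataset('ParcelsOut_Diffusion_1996.nc')
lon = data_xarray['lon'].values   # shape (num_particles, num_timesteps)
lat = data_xarray['lat'].values

# 2. The 207 reef polygons (rows and columns of the connectivity matrix)
gdf_all = gpd.read_file('Regions_w_geom_FINAL_crop.shp').sort_values(by='id')

# 3. Randall competency curve -> continuous probability of settlement vs larval age (days)
curve_data = pd.read_csv('Amil_cca_curve_SB.csv')
P_t_interpolator = interp1d(curve_data['LarvalAge'].values, curve_data['P_t'].values,
                            fill_value="extrapolate")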
After interpolating that curve into a continuous function, the code reads the shapefile into a GeoDataFrame, builds an R-tree index for rapid point-in-polygon lookups, and, for each year’s NetCDF, determines the last successfully processed day so it can resume without overwriting previous results.
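Those two pieces (the R-tree over the polygon bounding boxes and the resume logic) are short, so here they are condensed from the full code below (gdf_all as loaded above):

import os, glob
from rtree import index

# Bounding-box index over the reef polygons, keyed by their row position in gdf_all
idx = index.Index((j, geom.bounds, None) for j, geom in enumerate(gdf_all.geometry))

def get_most_recent_Tmin(pathout, year):
    """Return the hourly timestep to resume from, based on the last saved daily .npy file."""
    files = glob.glob(os.path.join(pathout, f'transitions_matrix_day*_{year}.npy'))
    if not files:
        return 0  # nothing saved yet, start from day 0
    files.sort(key=lambda x: int(x.split('day')[1].split('_')[0]))
    most_recent_day = int(files[-1].split('day')[1].split('_')[0])
    return (most_recent_day + 1) * 24  # first hour of the next unprocessed day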
Within the process_file function, the script opens the NetCDF, masks all -999 values to NaN (these are the beached particles), and discards particles whose initial positions are invalid. It also includes an optional debug step that reduces the particle sample size. Each surviving particle is then assigned to a “source” polygon by creating a Shapely Point at time 0 and querying the R-tree. Next, the script steps through each hourly timestep from the resume point up to 30 days: for every particle that still has valid coordinates, it draws a random number and, if that number falls below the interpolated Randall probability for the particle’s age in days, it queries the R-tree again and checks point-in-polygon membership to find the “sink” polygon.
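Both of those lookups (the source polygon at time 0 and the sink polygon at each later hour) follow the same pattern: get candidate reefs from the R-tree by bounding box, then run the exact geometry test. Written as a stand-alone helper (a hypothetical function, my script inlines this), the pattern is:

from shapely.geometry import Point

def find_polygon(x, y, idx, gdf_all):
    """Return the index of the reef polygon containing point (x, y), or None if there is none."""
    point = Point(x, y)
    for k in idx.intersection(point.bounds):   # cheap bounding-box pre-filter via the R-tree
        if point.within(gdf_all.geometry[k]):  # exact point-in-polygon test
            return k
    return None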
Each source–sink transition is tallied in a 207×207 matrix for that hour, where rows are source reefs and columns are sink reefs. After every 24 hours, the code stacks those hourly matrices into a 207×207×24 array and writes it out.
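Concretely, the end-of-day output step amounts to the following (a self-contained toy version with dummy zero matrices; in the real script the list is filled inside the hourly loop and the filename uses pathout and year):

import numpy as np

# Toy stand-in for the 24 hourly 207 x 207 count matrices accumulated over one day
all_daily_matrices = [np.zeros((207, 207), dtype=int) for _ in range(24)]

daily_array = np.stack(all_daily_matrices, axis=2)   # shape (207, 207, 24)
print(daily_array.shape)                             # (207, 207, 24)
np.save("transitions_matrix_day0_1996.npy", daily_array)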
Note: I have a separate copy of the script for each yearly run so I can run the years individually instead of sequentially.
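Each yearly copy is essentially the same script with the years line in the main block pointed at a different single year, for example:

# e.g. in the 1997 copy, only this line of the main block changes
years = range(1997, 1998)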
Here is the full code. Thanks for the help!
from rtree import index
import numpy as np
import time
import glob
from shapely.geometry import Polygon, Point
import csv
import pandas as pd
import xarray as xr
import geopandas as gpd
import os
from scipy.interpolate import interp1d
import random

# Load in Randall curve
curve_data = pd.read_csv('/scratch/pawsey0106/sbensadon/OceanParcels/CoralBay/Randall_comp_curves/Amil_cca_curve_SB.csv')
larval_age = curve_data['LarvalAge'].values
P_t = curve_data['P_t'].values
P_t_interpolator = interp1d(larval_age, P_t, fill_value="extrapolate")


def process_file(file_path, gdf_all, idx, Tmin, num_days=30, num_hours_per_day=24):
    print(f"Starting from Tmin Yobama {Tmin}", flush=True)
    data_xarray = xr.open_dataset(file_path, mode='r')
    lon = data_xarray['lon'].values  # I have also tried .astype(np.float32) but not significantly faster
    lat = data_xarray['lat'].values
    lon[lon == -999] = np.nan  # set beached particles to NaN
    lat[lat == -999] = np.nan  # set beached particles to NaN
    print(f"NaN count at t=0: {np.isnan(lon[:, 0]).sum()}, {np.isnan(lat[:, 0]).sum()}", flush=True)
    print(f"-999 count at t=0: {(lon[:, 0] == -999).sum()}, {(lat[:, 0] == -999).sum()}", flush=True)
    num_particles, num_timesteps = lon.shape

    valid_particles = ~np.isnan(lon[:, 0]) & ~np.isnan(lat[:, 0]) & (lon[:, 0] != -999) & (lat[:, 0] != -999)
    lon = lon[valid_particles, :]
    lat = lat[valid_particles, :]
    num_particles, num_timesteps = lon.shape

    ##############################################################################
    # Select only 10000 particles for debugging
    num_particles_to_debug = 10000
    if num_particles > num_particles_to_debug:
        sampled_indices = np.random.choice(num_particles, num_particles_to_debug, replace=False)
        lon = lon[sampled_indices, :]
        lat = lat[sampled_indices, :]
        print(f"Debugging with {lon.shape[0]} particles", flush=True)
        num_particles, num_timesteps = lon.shape
        print(f"num of particles {num_particles}", flush=True)
    ###########################################################################

    num_polygons = len(gdf_all)
    transitions_matrix = np.zeros((num_polygons, num_polygons), dtype=int)
    gdf_all = gdf_all.sort_values(by='id')

    print(f"Finding initial polys", flush=True)
    initial_polygons = []
    for i in range(num_particles):
        point_T0 = Point(lon[i, 0], lat[i, 0])
        initial_polygon = next((j for j in idx.intersection(point_T0.bounds) if point_T0.intersects(gdf_all.geometry[j])), None)
        initial_polygons.append(initial_polygon)
    print(f"Valid particles at t=0 w initial polygons: {num_particles}", flush=True)

    # Initialize the 3D numpy array to hold the transition matrices for each day
    all_daily_matrices = []

    # Loop over each timestep -- section of the code which takes the longest
    for t in range(int(Tmin), num_days * 24):
        if t >= num_timesteps:
            break
        if t % 24 == 0:
            print(f"Processing day {t // 24}", flush=True)

        # Initialize a new transitions matrix for this timestep
        transitions_matrix = np.zeros((num_polygons, num_polygons), dtype=int)

        # Loop over each particle
        for i in range(num_particles):
            point_current = Point(lon[i, t], lat[i, t])
            if np.isnan(lon[i, t]) or np.isnan(lat[i, t]):
                continue  # skip particles that are beached or out of domain (NaN coordinates)
            probability = P_t_interpolator(t / 24)
            # Check probability of settlement
            if np.random.rand() <= probability:
                for k in idx.intersection(point_current.bounds):
                    if point_current.within(gdf_all.geometry[k]):
                        transitions_matrix[initial_polygons[i], k] += 1
                        break  # exit once the sink polygon is found

        # Check if the transitions_matrix has a consistent shape
        print(f"Shape of transitions_matrix for day {t//24}: {transitions_matrix.shape}")

        # Append the transitions matrix for this timestep to the daily 3D array
        all_daily_matrices.append(transitions_matrix)

        # Once all 24 timesteps for a day are processed, combine into a 3D array and output
        if (t + 1) % 24 == 0:  # End of the day
            daily_array = np.stack(all_daily_matrices, axis=2)
            print(f"Shape of daily_array for day {t//24}: {daily_array.shape}")
            # Save the daily 3D numpy array
            daily_filename = f"{pathout}/transitions_matrix_day{t//24}_{year}.npy"
            np.save(daily_filename, daily_array)
            print(f"Saved daily transitions matrix for day {t//24} to {daily_filename}", flush=True)
            # Reset for the next day
            all_daily_matrices = []

    return transitions_matrix, num_timesteps


# Get most recent Tmin so the run can restart from the last outputted day
def get_most_recent_Tmin(pathout, year):
    # Find all the transition matrix files for the given year
    file_pattern = os.path.join(pathout, f'transitions_matrix_day*_{year}.npy')
    files = glob.glob(file_pattern)
    if not files:
        return 0  # If no files are found, start from day 0
    # Sort the files based on the day extracted from the filename
    files.sort(key=lambda x: int(x.split('day')[1].split('_')[0]))  # Extract day number and sort
    # Get the most recent file and determine the day (Tmin)
    most_recent_file = files[-1]
    most_recent_day = int(most_recent_file.split('day')[1].split('_')[0])
    # Set Tmin as the most recent day + 1 (to restart from the next day)
    Tmin = (most_recent_day + 1) * 24
    print(f"Most recent file: {most_recent_file}, Tmin set to: {Tmin}")
    return Tmin


# Main to run the code
if __name__ == "__main__":
    # Path to netcdf particle tracking files
    base_path = '/scratch/pawsey0106/sbensadon/OceanParcels/CoralBay/SensitivityAnalysis/pout/'
    pathout = '/scratch/pawsey0106/sbensadon/OceanParcels/CoralBay/conmat/'
    # Just running for a single year (I have 30 years to run!)
    years = range(1996, 1997)
    # Reef polygon shapefile
    gdf_all = gpd.read_file("/scratch/pawsey0106/sbensadon/OceanParcels/CoralBay/Coral_communities/Regions_w_geom_FINAL_crop.shp")
    gdf_all = gdf_all.sort_values(by='id')
    idx = index.Index((j, geom.bounds, None) for j, geom in enumerate(gdf_all.geometry))
    # Process each file sequentially
    for year in years:
        Tmin = get_most_recent_Tmin(pathout, year)
        print(f"Using Tmin: {Tmin} for year {year}", flush=True)
        file_path = f'{base_path}/ParcelsOut_Diffusion_{year}.nc'
        transitions_matrix, num_timesteps = process_file(file_path, gdf_all, idx, Tmin=Tmin)  # Capture both outputs
    print(f"All transition matrices saved individually to {pathout}.")
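For completeness, each daily output file can be read back with np.load; it holds the hourly source-by-sink counts for that day, which can be summed over the hour axis to get a daily connectivity matrix:

import numpy as np

daily = np.load("transitions_matrix_day0_1996.npy")   # one saved day
print(daily.shape)                      # (207, 207, 24): source reef x sink reef x hour
daily_connectivity = daily.sum(axis=2)  # collapse the 24 hours into daily counts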