You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
61 lines
2.3 KiB
61 lines
2.3 KiB
import numpy as np
|
|
import os
|
|
from pathlib import Path
|
|
import matplotlib.pyplot as plt
|
|
|
|
def _draw_panel(fig, ax, points, title):
    """Scatter one frame's points on ``ax`` (side on x, forward on y) and add a colorbar."""
    # Columns: [x, y, z, velocity, magnitude]
    # CARLA: X is Forward, Y is Side
    scatter = ax.scatter(
        points[:, 1], points[:, 0], c=points[:, 4],
        cmap='viridis', s=10, vmin=0, vmax=100,
    )
    ax.set_title(title)
    ax.set_xlabel("Side (Y)")
    ax.set_ylabel("Forward (X)")
    ax.set_xlim([-30, 30])
    ax.set_ylim([0, 100])
    fig.colorbar(scatter, ax=ax, label='Magnitude')


def plot_comparison(base_path, iter_a: str, iter_b: str, start_frame: int, end_frame: int, output_dir, model: str = "awrl1432") -> None:
    """Save side-by-side scatter plots comparing two iterations, frame by frame.

    Frames are read from ``<base_path>/iterations/<iter>/<model>/frame_NNNNNN.npy``
    for every index in ``start_frame..end_frame`` (inclusive). Each frame present
    in both iterations is written to ``<output_dir>/compare_frame_NNNNNN.png``.
    Both panels use the same fixed axis limits and colour range.

    Args:
        base_path: Root directory (str or path-like) holding the ``iterations`` folder.
        iter_a: Iteration folder name drawn in the left panel.
        iter_b: Iteration folder name drawn in the right panel.
        start_frame: First frame index to plot.
        end_frame: Last frame index to plot (inclusive).
        output_dir: Directory for the PNG files; created if it does not exist.
        model: Sub-folder of each iteration that contains the frame files.

    Returns:
        None. Progress is reported on stdout: ``[ERROR]`` when either iteration
        directory is missing (nothing is plotted), ``[SKIP]`` for a frame
        missing on either side, ``[DONE]`` for each saved image.
    """
    path_a = Path(base_path) / "iterations" / iter_a / model
    path_b = Path(base_path) / "iterations" / iter_b / model
    out_path = Path(output_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    if not path_a.exists() or not path_b.exists():
        print(f"[ERROR] Paths not found. A: {path_a}, B: {path_b}")
        return

    for index in range(start_frame, end_frame + 1):
        stem = f"frame_{index:06d}"
        frame_name = f"{stem}.npy"
        f_a = path_a / frame_name
        f_b = path_b / frame_name

        if not f_a.exists() or not f_b.exists():
            print(f"[SKIP] {frame_name} missing.")
            continue

        data_a = np.load(f_a)
        data_b = np.load(f_b)

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
        try:
            _draw_panel(fig, ax1, data_a, f"{iter_a} - {frame_name}")
            _draw_panel(fig, ax2, data_b, f"{iter_b} - {frame_name}")

            fig.tight_layout()
            save_name = out_path / f"compare_{stem}.png"
            fig.savefig(save_name)
        finally:
            # Close this specific figure even when plotting or saving fails,
            # so figures do not pile up over a long frame range.
            plt.close(fig)
        print(f"[DONE] Saved {save_name.name}")
|
|
|
|
if __name__ == "__main__":
    # IMPORTANT: artifacts go to the conversation directory, addressed by absolute path.
    conversation_artifacts = r"C:\Users\rakadu1\.gemini\antigravity\brain\67913a3c-cbc2-4fba-87e3-88fbea20f043\artifacts"
    debug_root = "Shenron_debug"

    # Compare frames 200..215 (inclusive) of the two 14b iterations.
    plot_comparison(
        base_path=debug_root,
        iter_a="14b_normalization",
        iter_b="14b_stress_test",
        start_frame=200,
        end_frame=215,
        output_dir=conversation_artifacts,
    )
|