From 507d20390c294466a22ad8722341e7510f28222c Mon Sep 17 00:00:00 2001
From: xinyangli <lixinyang411@gmail.com>
Date: Fri, 16 Aug 2024 12:27:24 +0800
Subject: [PATCH] fix: exit execution when first difftest failed

---
 .gdbinit             |  2 +-
 include/difftest.hpp | 13 +++++++------
 src/difftest.cpp     | 33 +++++++++++++++++++--------------
 src/main.cpp         | 11 ++++++++++-
 4 files changed, 37 insertions(+), 22 deletions(-)

diff --git a/.gdbinit b/.gdbinit
index c1022cc..0199ab7 100644
--- a/.gdbinit
+++ b/.gdbinit
@@ -1,4 +1,4 @@
-file /nix/store/ijxm784gr0sx5p4d92rlag0ippyd0mvm-am-kernel-riscv32-none-elf-2024-07-10/libexec/am-kernels/bench
+file /nix/store/ijxm784gr0sx5p4d92rlag0ippyd0mvm-am-kernel-riscv32-none-elf-2024-07-10/libexec/am-kernels/demo
 
 set substitute-path /build/am-kernels /home/xin/repo/ysyx-workbench/am-kernels
 set substitute-path /build/abstract-machine /home/xin/repo/ysyx-workbench/abstract-machine
diff --git a/include/difftest.hpp b/include/difftest.hpp
index 2f7b301..211a6ff 100644
--- a/include/difftest.hpp
+++ b/include/difftest.hpp
@@ -23,12 +23,6 @@ private:
     return __atomic_load_n(&halt_status, __ATOMIC_RELAXED);
   };
 
-  struct ExecRet {
-    bool at_breakpoint;
-    bool do_difftest;
-  };
-  ExecRet exec(size_t n, gdb_action_t *ret);
-
 public:
   Difftest(Target &&dut, std::vector<Target> &&refs);
 
@@ -44,6 +38,13 @@ public:
   bool set_bp(size_t addr, bp_type_t type);
   bool del_bp(size_t addr, bp_type_t type);
 
+  struct ExecRet {
+    bool at_breakpoint;
+    bool do_difftest;
+    bool check_failed;
+  };
+  ExecRet exec(size_t n, gdb_action_t *ret);
+
   bool check_all();
   int sync_regs_to_ref(void);
   std::string list_targets(void);
diff --git a/src/difftest.cpp b/src/difftest.cpp
index 0567ae3..8a09618 100644
--- a/src/difftest.cpp
+++ b/src/difftest.cpp
@@ -56,7 +56,8 @@ bool Difftest::check_all() {
 }
 
 Difftest::ExecRet Difftest::exec(size_t n, gdb_action_t *ret) {
-  ExecRet exec_ret = {.at_breakpoint = false, .do_difftest = true};
+  ExecRet exec_ret = {
+      .at_breakpoint = false, .do_difftest = true, .check_failed = false};
   while (n--) {
     Target *pbreak = &(*(this->begin()));
     // TODO: For improvement, use ThreadPool here for concurrent execution?
@@ -74,23 +75,35 @@ Difftest::ExecRet Difftest::exec(size_t n, gdb_action_t *ret) {
       exec_ret.do_difftest = *target.do_difftest && exec_ret.do_difftest;
     }
 
+    // Do difftest, or sync registers to ref
+    if (exec_ret.do_difftest) {
+      if (!check_all()) {
+        exec_ret.check_failed = true;
+      }
+    } else {
+      size_t pc = 0;
+      read_reg(32, &pc);
+      spdlog::debug("Difftest skipped at {}", (void *)pc);
+      sync_regs_to_ref();
+    }
+
+    if (exec_ret.check_failed) {
+      ret->reason = gdb_action_t::ACT_SHUTDOWN;
+    }
+
     if (exec_ret.at_breakpoint) {
       ret->reason = pbreak->last_res.reason;
       ret->data = pbreak->last_res.data;
       break;
     }
   }
+
   return exec_ret;
 }
 
 gdb_action_t Difftest::stepi() {
   gdb_action_t ret = {.reason = gdb_action_t::ACT_NONE};
   ExecRet exec_result = exec(1, &ret);
-  if (exec_result.do_difftest) {
-    check_all();
-  } else {
-    sync_regs_to_ref();
-  }
   return ret;
 }
 
@@ -100,14 +113,6 @@ gdb_action_t Difftest::cont() {
   start_run();
   while (!is_halt()) {
     exec_ret = exec(1, &ret);
-    if (exec_ret.do_difftest) {
-      check_all();
-    } else {
-      size_t pc = 0;
-      read_reg(32, &pc);
-      spdlog::debug("Difftest skipped at {}", (void *)pc);
-      sync_regs_to_ref();
-    }
     if (exec_ret.at_breakpoint)
       break;
   };
diff --git a/src/main.cpp b/src/main.cpp
index c2b3bf3..6dea91f 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -38,7 +38,16 @@ int main(int argc, char **argv) {
   if (config.use_debugger) {
     gdbstub_loop(&difftest, config.gdbstub_addr);
   } else {
-    difftest.cont();
+    gdb_action_t ret = {.reason = gdb_action_t::ACT_NONE};
+    Difftest::ExecRet exec_ret;
+    while (1) {
+      exec_ret = difftest.exec(1, &ret);
+      if (exec_ret.check_failed)
+        return 1;
+      if (exec_ret.at_breakpoint)
+        break;
+    };
+    return 0;
   }
 
   return 0;