diff --git a/rpd_tracer/Logger.cpp b/rpd_tracer/Logger.cpp index 0d03a24..851f4be 100644 --- a/rpd_tracer/Logger.cpp +++ b/rpd_tracer/Logger.cpp @@ -63,6 +63,20 @@ void rpdflush() Logger::singleton().rpdflush(); } +void rpdFinalize() +{ + Logger::singleton().rpdFinalize(); +} +void rpdInit() +{ + Logger::singleton().rpdInit(); +} + +void rpdResetFinalize() +{ + Logger::singleton().rpdResetFinalize(); +} + void rpd_rangePush(const char *domain, const char *apiName, const char* args) { Logger::singleton().rpd_rangePush(domain, apiName, args); @@ -95,6 +109,10 @@ void Logger::rpdInit() { void Logger::rpdFinalize() { Logger::singleton().finalize(); } +void Logger::rpdResetFinalize() +{ + Logger::singleton().set_finalize_true(); +} void Logger::rpdstart() @@ -249,6 +267,12 @@ void Logger::init() static bool doFinalize = true; std::mutex finalizeMutex; +void Logger::set_finalize_true() +{ + std::lock_guard guard(finalizeMutex); + doFinalize = true; +} + void Logger::finalize() { std::lock_guard guard(finalizeMutex); diff --git a/rpd_tracer/Logger.h b/rpd_tracer/Logger.h index 07a6d21..8d6e966 100644 --- a/rpd_tracer/Logger.h +++ b/rpd_tracer/Logger.h @@ -48,6 +48,7 @@ class Logger void rpdstart(); void rpdstop(); void rpdflush(); + void rpdResetFinalize(); // External maker api void rpd_rangePush(const char *domain, const char *apiName, const char* args); @@ -80,6 +81,7 @@ class Logger void init(); void finalize(); + void set_finalize_true(); std::string m_filename; bool m_writeOverheadRecords {true}; diff --git a/rpd_tracer/rpdTracerControl.py b/rpd_tracer/rpdTracerControl.py index 9bbdc39..6e0ea0d 100644 --- a/rpd_tracer/rpdTracerControl.py +++ b/rpd_tracer/rpdTracerControl.py @@ -65,7 +65,7 @@ def initializeFile(cls): connection.close() # os.environ["RPDT_FILENAME"] = cls.__filename - cls.__initFile = False + cls.__initFile = False # You can set the output filename and optionally append to an exiting file. @@ -87,9 +87,26 @@ def setFilename(cls, name, append = False): os.environ["RPDT_FILENAME"] = cls.__filename cls.__initFile = False - def __init__(self): + @classmethod + def rpdReset(cls): + rpdTracerControl.__rpd.rpdResetFinalize() + rpdTracerControl.__rpd.rpdFinalize() + #cls.__initFile = True + + def __init__(self, file_name = None, nvtx = False): + if file_name != None: + # Force reset filename, if we are getting called in a loop. + rpdTracerControl.__initFile = True + rpdTracerControl.__filename = file_name rpdTracerControl.initializeFile() rpdTracerControl.loadLibrary() + self.nvtx = None + if nvtx: + import torch + self.nvtx = torch.autograd.profiler.emit_nvtx() + if file_name != None: + # Reinit for the new file + rpdTracerControl.__rpd.rpdInit() def __del__(self): pass @@ -105,9 +122,46 @@ def flush(self): def __enter__(self): self.start() + if self.nvtx: + self.nvtx.__enter__() + return self def __exit__(self, exc_type, exc_val, exc_tb): + if self.nvtx: + self.nvtx.__exit__(exc_type, exc_val, exc_tb) + if exc_type != None: + #Propagate exception + return False self.stop() + self.flush() + rpdTracerControl.rpdReset() + + def top_totals(self): + try: + conn = sqlite3.connect(rpdTracerControl.__filename) + cursor = conn.cursor() + cursor.execute("SELECT Name, TotalCalls, TotalDuration, Ave, Percentage FROM top;") + rows = cursor.fetchall() + + if rows: + from prettytable import PrettyTable + import textwrap + table = PrettyTable() + table.field_names = ["Name", "TotalCalls", "TotalDuration", "Ave", "Percentage"] + table.align = "l" + + for row in rows: + wrapped_name = '\n'.join(textwrap.wrap(row[0], 60)) + table.add_row([wrapped_name] + list(row[1:])) + + print(table) + else: + print("No data found in 'top' table.") + + except sqlite3.Error as e: + print(f"Error querying database: {e}") + finally: + conn.close() def rangePush(self, domain: str, apiName: str, args: str): rpdTracerControl.__rpd.rpd_rangePush(bytes(domain, encoding='utf-8'), bytes(apiName, encoding='utf-8'), bytes(args, encoding='utf-8'))