From b7e541e02f331619107273c22f2dc7813b2a52ed Mon Sep 17 00:00:00 2001
From: Hong Chen <hchen99@users.noreply.github.com>
Date: Thu, 14 Dec 2023 13:28:16 -0600
Subject: [PATCH] Updated not to add "friend class" to embedded private list so
 it can be accessed.  (#1625)

* Updated not to add "friend class" to embedded private list so it can be accessed.

* Added a list to hold all friends class decl so able to specifically make sure not to remove them from io source but not to allow is source for others.

* Added namespaces to class names for comparison as class name only needs to be unique within the same namespace.

* Added if NULL check before getting friend decl type string.
---
 .../Interface_Code_Gen/ClassVisitor.cpp       | 39 ++++++++++++++++++-
 .../Interface_Code_Gen/ClassVisitor.hh        |  1 +
 .../Interface_Code_Gen/ConstructValues.cpp    |  7 ++++
 .../Interface_Code_Gen/ConstructValues.hh     |  1 +
 4 files changed, 46 insertions(+), 2 deletions(-)

diff --git a/trick_source/codegen/Interface_Code_Gen/ClassVisitor.cpp b/trick_source/codegen/Interface_Code_Gen/ClassVisitor.cpp
index 1891a6ba..0f9d70bd 100644
--- a/trick_source/codegen/Interface_Code_Gen/ClassVisitor.cpp
+++ b/trick_source/codegen/Interface_Code_Gen/ClassVisitor.cpp
@@ -68,8 +68,21 @@ bool CXXRecordVisitor::TraverseDecl(clang::Decl *d) {
                     // test against them.
                     ClassValues temp_cv ;
                     temp_cv.getNamespacesAndClasses(crd->getDeclContext()) ;
-                    private_embedded_classes.insert(temp_cv.getFullyQualifiedName() + crd->getNameAsString()) ;
-                    //std::cout << "marking private " << temp_cv.getFullyQualifiedName() + crd->getNameAsString() << std::endl ;
+                    /*
+                    for (auto const &pec : private_embedded_classes) {
+                        std::cout << "===private_embedded_classes..." << pec << "===" << std::endl;
+                    }
+                    if (friend_classes.size() > 0) {
+                        for (auto const &fc : friend_classes) {
+                            std::cout << "===friend_classes..." << fc << "===" << std::endl;
+                        }
+                    }
+                    */
+                    std::string class_str = temp_cv.getFullyQualifiedName() + crd->getNameAsString();
+                    // Private embedded classes are not printed to io source unless they are friend classes.
+                    if (friend_classes.find(class_str) == friend_classes.end()) {
+                        private_embedded_classes.insert(class_str);
+                    }
                 }
             }
         }
@@ -104,6 +117,26 @@ bool CXXRecordVisitor::TraverseDecl(clang::Decl *d) {
         }
         break ;
         case clang::Decl::Friend : {
+            ClassValues temp_cv ;
+            temp_cv.getNamespacesAndClasses(d->getDeclContext()) ;
+            clang::FriendDecl * fd = static_cast<clang::FriendDecl *>(d) ;
+            std::string class_str;
+
+            // Only use namespaces for identifying class name as the class name can't be the same within the same namespace.
+            if (fd->getFriendDecl() != NULL) {
+                class_str = temp_cv.getNameOnlyWithNamespaces() + fd->getFriendDecl()->getNameAsString();
+            } 
+            // For friend class, only need to get type here but the above getting class_str is for just in case needed.
+            if (fd->getFriendType() != NULL) {
+                class_str = temp_cv.getNameOnlyWithNamespaces() + fd->getFriendType()->getType().getAsString();
+            }
+            size_t pos;
+            // Only save class name to the friend class list
+            if ((pos = class_str.find("class ")) != std::string::npos ) {
+                class_str.erase(pos , 6) ;
+                friend_classes.insert(class_str);
+            }
+
             TraverseFriendDecl(static_cast<clang::FriendDecl *>(d)) ;
         }
         break ;
@@ -331,6 +364,8 @@ ClassValues * CXXRecordVisitor::get_class_data() {
 
 std::set<std::string> CXXRecordVisitor::private_embedded_classes ;
 
+std::set<std::string> CXXRecordVisitor::friend_classes ;
+
 void CXXRecordVisitor::addPrivateEmbeddedClass( std::string in_name ) {
     private_embedded_classes.insert(in_name) ;
 }
diff --git a/trick_source/codegen/Interface_Code_Gen/ClassVisitor.hh b/trick_source/codegen/Interface_Code_Gen/ClassVisitor.hh
index 30676c39..2158bbb6 100644
--- a/trick_source/codegen/Interface_Code_Gen/ClassVisitor.hh
+++ b/trick_source/codegen/Interface_Code_Gen/ClassVisitor.hh
@@ -80,6 +80,7 @@ class CXXRecordVisitor : public clang::RecursiveASTVisitor<CXXRecordVisitor> {
         bool access_spec_found ;
 
         static std::set<std::string> private_embedded_classes ;
+        static std::set<std::string> friend_classes ;
 } ;
 
 #endif
diff --git a/trick_source/codegen/Interface_Code_Gen/ConstructValues.cpp b/trick_source/codegen/Interface_Code_Gen/ConstructValues.cpp
index 370779e5..f4d43f52 100644
--- a/trick_source/codegen/Interface_Code_Gen/ConstructValues.cpp
+++ b/trick_source/codegen/Interface_Code_Gen/ConstructValues.cpp
@@ -219,3 +219,10 @@ std::string ConstructValues::getFullyQualifiedTypeName(const std::string& delimi
     oss << name ;
     return oss.str() ;
 }
+
+std::string ConstructValues::getNameOnlyWithNamespaces(const std::string& delimiter) {
+    std::ostringstream oss ;
+    printNamespaces(oss, delimiter) ;
+    oss << name ;
+    return oss.str() ;
+}
diff --git a/trick_source/codegen/Interface_Code_Gen/ConstructValues.hh b/trick_source/codegen/Interface_Code_Gen/ConstructValues.hh
index c92d6436..0af58069 100644
--- a/trick_source/codegen/Interface_Code_Gen/ConstructValues.hh
+++ b/trick_source/codegen/Interface_Code_Gen/ConstructValues.hh
@@ -61,6 +61,7 @@ class ConstructValues {
         }
 
         std::string getFullyQualifiedName(const std::string& delimiter = "::") ;
+        std::string getNameOnlyWithNamespaces(const std::string& delimiter = "::");
 
         void printOpenNamespaceBlocks(std::ostream& ostream);
         void printCloseNamespaceBlocks(std::ostream& ostream);