diff --git a/go/custom-set/cmd/cmd b/go/custom-set/cmd/cmd index 8bd1f22..23eb61c 100755 Binary files a/go/custom-set/cmd/cmd and b/go/custom-set/cmd/cmd differ diff --git a/go/custom-set/cmd/main.go b/go/custom-set/cmd/main.go index b993aab..88c867e 100644 --- a/go/custom-set/cmd/main.go +++ b/go/custom-set/cmd/main.go @@ -8,42 +8,22 @@ import ( func main() { fmt.Println("Creating Set 1") - s1 := stringset.NewFromSlice([]string{"a", "b"}) - addAndOutput(s1, "A") + s1 := stringset.NewFromSlice([]string{"a"}) + addAndOutput(s1, "b") + addAndOutput(s1, "d") + s1.PrettyPrint() + return addAndOutput(s1, "c") - addAndOutput(s1, "B") - addAndOutput(s1, "B1") - addAndOutput(s1, "A1") - addAndOutput(s1, "B2") - addAndOutput(s1, "A2") - addAndOutput(s1, "B3") - addAndOutput(s1, "A3") - delAndOutput(s1, "a") - - fmt.Println("Creating Set 2") - s2 := stringset.NewFromSlice([]string{"A", "c"}) - addAndOutput(s2, "a") - addAndOutput(s2, "b") - addAndOutput(s2, "B1") - addAndOutput(s2, "A3") - addAndOutput(s2, "A2") - addAndOutput(s2, "B2") - addAndOutput(s2, "A1") - addAndOutput(s2, "B3") - addAndOutput(s2, "B") - s2.PrettyPrint() - delAndOutput(s2, "a") - s2.PrettyPrint() + addAndOutput(s1, "0") + addAndOutput(s1, "aa") + addAndOutput(s1, "aaa") } -func addAndOutput(s *stringset.Set, val string) { - fmt.Println("Adding " + val) +func addAndOutput(s stringset.Set, val string) { + fmt.Println("Adding new value: " + val) s.Add(val) - fmt.Println(s.String()) } -func delAndOutput(s *stringset.Set, val string) { - fmt.Println("Deleting " + val) +func delAndOutput(s stringset.Set, val string) { s.Delete(val) - fmt.Println(s.String()) } diff --git a/go/custom-set/custom_set_test.go b/go/custom-set/custom_set_test.go index 7b8d0ba..c2cacb5 100644 --- a/go/custom-set/custom_set_test.go +++ b/go/custom-set/custom_set_test.go @@ -27,6 +27,7 @@ package stringset // Format the empty set as {}. import ( + "fmt" "math/rand" "reflect" "strconv" @@ -166,13 +167,17 @@ func TestEqual(t *testing.T) { // helper for testing Add, Delete func testEleOp(name string, op func(Set, string), cases []eleOpCase, t *testing.T) { for _, tc := range cases { + fmt.Print("Running Test Case: ") + fmt.Println(tc) s := NewFromSlice(tc.set) op(s, tc.ele) want := NewFromSlice(tc.want) if !Equal(s, want) { + fmt.Println(s.String()) t.Fatalf("%v %s %q = %v, want %v", NewFromSlice(tc.set), name, tc.ele, s, want) } + fmt.Println("=== Done ===") } } diff --git a/go/custom-set/stringset.go b/go/custom-set/stringset.go index b126270..f99a4c6 100644 --- a/go/custom-set/stringset.go +++ b/go/custom-set/stringset.go @@ -3,6 +3,7 @@ package stringset import ( "errors" "fmt" + "strconv" "strings" ) @@ -11,30 +12,30 @@ const testVersion = 3 // Set is a slice of strings that you can do set operations on. // I decided that I wanted to implement a binary tree for the storage type Set struct { - top SetValue + top *Node } -type SetValue struct { +type Node struct { value string - left *SetValue - right *SetValue + left *Node + right *Node } // New returns an empty Set func New() Set { s := new(Set) - s.top = SetValue{} + s.top = &Node{} return *s } // NewFromSlice takes a slice of strings and returns a Set func NewFromSlice(s []string) Set { ret := New() + r := &ret for i := range s { - ret.Add(s[i]) + r.Add(s[i]) } - ret.balance() - return ret + return *r } // Add adds a value to the set @@ -43,71 +44,32 @@ func (s Set) Add(v string) { s.top.value = v return } - s.top.Add(v) -} + if !s.Has(v) { + bef := "Current Top Node: " + s.top.value + ";L:" + s.top.getLeftValue() + ";R:" + s.top.getRightValue() + newNode := s.top.Add(&Node{value: v}) + fmt.Println(bef) + fmt.Println("Current New Node: " + newNode.value + ";L:" + newNode.getLeftValue() + ";R:" + newNode.getRightValue()) + fmt.Printf("newNode.left: %v (%v)\n", &newNode.left, newNode.left) + s.top.value = newNode.value + s.top.left = newNode.left + s.top.right = newNode.right + fmt.Println("Current Top Node: " + s.top.value + ";L:" + s.top.getLeftValue() + ";R:" + s.top.getRightValue()) + fmt.Printf("s.top.left: %v (%v)\n", &s.top.left, s.top.left) + fmt.Println("Add Done") -func (s *Set) PrettyPrint() { - s.pp(&s.top, 0) -} - -func (s *Set) pp(n *SetValue, indent int) { - if n != nil { - if n.left != nil { - s.pp(n.left, indent+4) - } - if n.right != nil { - s.pp(n.right, indent+4) - } - for ; indent > 0; indent-- { - fmt.Print(" ") - } - fmt.Println(n.value) + //s.top = s.top.Add(newNode) + //fmt.Println("Current Top Node: " + s.top.value + ";L:" + s.top.getLeftValue() + ";R:" + s.top.getRightValue()) } } -// balance balances the binary tree -func (s *Set) balance() { - -} - -func (s *Set) find(v string) *SetValue { - if s.top.value != "" { - sv, _ := s.top.find(v) - return sv - } - return nil -} - -// findWithParent finds a node with a child value v -// a return of nil, nil means it's the top node -func (s *Set) findParent(v string) (*SetValue, error) { - if s.top.value == v { - // no parent, it's the top. - return nil, nil - } else { - return s.top.findParent(v) - } - return nil, errors.New("Empty Set") -} - -func (s *Set) findHome(v *SetValue) { - if s.top.value == "" { - s.top.value = v.value - s.top.left = v.left - s.top.right = v.right - return - } - s.top.findHome(v) -} - // Delete removes the given value from the set func (s Set) Delete(v string) { if sv, err := s.findParent(v); err == nil { var cmp int - var delNode, repNode, orphan *SetValue + var delNode, repNode, orphan *Node if sv == nil { // Deleting 'top' - delNode = &s.top + delNode = s.top if delNode.left != nil { repNode = delNode.left orphan = delNode.right @@ -118,15 +80,13 @@ func (s Set) Delete(v string) { // No node to replace it with, we're done return } - s.top = *repNode + s.top = repNode } else { cmp = strings.Compare(v, sv.value) if cmp < 0 && sv.left != nil { - fmt.Println(" Left: " + sv.left.value) // It's the left node delNode = sv.left } else if cmp > 0 && sv.right != nil { - fmt.Println(" Right: " + sv.right.value) // It's the right node delNode = sv.right } @@ -157,6 +117,62 @@ func (s Set) Delete(v string) { } } +func (s Set) PrettyPrint() { + fmt.Println("=========") + s.pp(s.top, 0) + fmt.Println("=========") +} + +func (s Set) pp(n *Node, indent int) { + if n.left == n || n.right == n { + // Circular reference... + fmt.Println("Circular reference... You done messed up.") + return + } + if n != nil { + if n.left != nil { + s.pp(n.left, indent+4) + } + if n.right != nil { + s.pp(n.right, indent+4) + } + for ; indent > 0; indent-- { + fmt.Print(" ") + } + fmt.Println(n.value) + } +} + +func (s *Set) find(v string) *Node { + if s.top.value != "" { + sv, _ := s.top.find(v) + return sv + } + return nil +} + +// findWithParent finds a node with a child value v +// a return of nil, nil means it's the top node +func (s *Set) findParent(v string) (*Node, error) { + if s.top.value == v { + // no parent, it's the top. + return nil, nil + } else { + return s.top.findParent(v) + } + return nil, errors.New("Empty Set") +} + +func (s *Set) findHome(v *Node) { + if s.top.value == "" { + s.top.value = v.value + s.top.left = v.left + s.top.right = v.right + return + } + s.top.findHome(v) +} + // Has returns if the set contains the given value. func (s *Set) Has(v string) bool { return s.find(v) != nil @@ -195,7 +211,7 @@ func (s Set) String() string { // find looks for a node with value val, it either returns the node // or an error stating it couldn't find it. -func (sv *SetValue) find(val string) (*SetValue, error) { +func (sv *Node) find(val string) (*Node, error) { if sv.value == val { return sv, nil } @@ -211,7 +227,7 @@ func (sv *SetValue) find(val string) (*SetValue, error) { // findParent looks for the parent of the node with value val // If nil, nil is returned, it _is_ this node. -func (sv *SetValue) findParent(val string) (*SetValue, error) { +func (sv *Node) findParent(val string) (*Node, error) { if sv.value == val { // This should only trigger if this is the top node of the tree return nil, nil @@ -232,7 +248,7 @@ func (sv *SetValue) findParent(val string) (*SetValue, error) { return nil, errors.New("Value not found") } -func (sv *SetValue) findHome(v *SetValue) { +func (sv *Node) findHome(v *Node) { cmp := strings.Compare(v.value, sv.value) if cmp < 0 { if sv.left == nil { @@ -257,25 +273,86 @@ func (sv *SetValue) findHome(v *SetValue) { } } -func (sv *SetValue) Add(v string) { - cmp := strings.Compare(v, sv.value) +func (sv *Node) Add(n *Node) *Node { + fmt.Println("1 Adding at Node " + sv.value) + cmp := strings.Compare(n.value, sv.value) if cmp < 0 { + fmt.Println("2 checking to left") if sv.left == nil { - sv.left = &SetValue{value: v} + fmt.Println("3 left is nil, add it") + sv.left = n } else { - sv.left.Add(v) + fmt.Println("--> 4 recurse left") + sv.left = sv.left.Add(n) + fmt.Println("<-- 5 added to left") } - } else if cmp > 0 { + } else { + fmt.Println("6 checking to right") if sv.right == nil { - sv.right = &SetValue{value: v} + fmt.Println("7 right is nil, add it") + sv.right = n } else { - sv.right.Add(v) + fmt.Println("--> 8 recurse right") + sv.right = sv.right.Add(n) + fmt.Println("<-- 9 added to right") } } + + fmt.Println("10 let's balance now (sv: " + sv.value + ")") + balance := GetBalance(sv) + // left left unbalance + fmt.Println("11 balance: " + strconv.Itoa(balance)) + newVal := sv + if balance > 1 && strings.Compare(n.value, sv.getLeftValue()) < 0 { // < sv.left.value { + fmt.Println("12 Left Left") + newVal = RightRotate(sv) + fmt.Println("13") + } + // right right unbalance + if balance < -1 && strings.Compare(n.value, sv.getRightValue()) > 0 { //n.value > sv.right.value { + fmt.Println("14 Right Right") + newVal = LeftRotate(sv) + fmt.Println("15") + } + // left right unbalance + if balance > 1 && strings.Compare(n.value, sv.getLeftValue()) > 0 { //n.value > sv.left.value { + fmt.Println("16 Left Right") + sv.left = LeftRotate(sv.left) + fmt.Println("17") + newVal = RightRotate(sv) + fmt.Println("18") + } + // right left unbalance + if balance < -1 && strings.Compare(n.value, sv.getRightValue()) < 0 { //n.value < sv.right.value { + fmt.Println("19 Right Left") + sv.right = RightRotate(sv.right) + fmt.Println("20") + newVal = LeftRotate(sv) + fmt.Println("21") + } + fmt.Println("22 end add") + return newVal +} + +func (sv *Node) getLeftValue() string { + if sv.left == nil { + return "" + } + return sv.left.value +} + +func (sv *Node) getRightValue() string { + if sv.right == nil { + return "" + } + return sv.right.value } // Len returns how many elements are in the branches -func (sv *SetValue) Len() int { +func (sv *Node) Len() int { + if sv == nil { + return 0 + } ret := 1 if sv.left != nil { ret += sv.left.Len() @@ -287,26 +364,34 @@ func (sv *SetValue) Len() int { } // Has checks if this branch contains the value v -func (sv *SetValue) Has(v string) bool { +func (sv *Node) Has(v string) bool { ret, _ := sv.find(v) return ret != nil } // String gets a string value of this branch -func (sv *SetValue) String() string { +func (sv *Node) String() string { var ret string if sv.left != nil { - ret += sv.left.String() + ", " + if sv.left == sv { + ret += "**, " + } else { + ret += sv.left.String() + ", " + } } ret += "\"" + sv.value + "\"" if sv.right != nil { - ret += ", " + sv.right.String() + if sv.right == sv { + ret += "**, " + } else { + ret += ", " + sv.right.String() + } } return ret } // Slice returns a string slice of all values in the branch -func (sv *SetValue) Slice() []string { +func (sv *Node) Slice() []string { var ret []string if sv.left != nil { ret = sv.left.Slice() @@ -325,9 +410,6 @@ func Equal(s1, s2 Set) bool { // Subset returns whether s1 is a subset of s2. func Subset(s1, s2 Set) bool { - if s1.Len() == 0 || s2.Len() == 0 { - return false - } s1Sl := s1.Slice() for i := range s1Sl { if !s2.Has(s1Sl[i]) { @@ -345,7 +427,7 @@ func Disjoint(s1, s2 Set) bool { return false } } - return false + return true } // Intersection finds elements that exist in both sets and makes a new @@ -386,3 +468,41 @@ func Difference(s1, s2 Set) Set { func SymmetricDifference(s1, s2 Set) Set { return Union(Difference(s1, s2), Difference(s2, s1)) } + +/* Helper Functions for Balancing the Tree */ +func RightRotate(y *Node) *Node { + x := y.left + t2 := x.right + // Perform rotation + x.right = y + y.left = t2 + fmt.Println("RightRotate: return=" + x.value + "; L:" + x.left.value + "; R:" + x.right.value) + fmt.Println("RightRotate: leftNode=" + x.getLeftValue() + "; L:" + x.left.getLeftValue() + "; R:" + x.left.getRightValue()) + fmt.Println("RightRotate: rightNode=" + x.getRightValue() + "; L:" + x.right.getLeftValue() + "; R:" + x.right.getRightValue()) + // Return new root + return x +} + +func LeftRotate(x *Node) *Node { + y := x.right + t2 := y.left + // Perform rotation + y.left = x + x.right = t2 + // Return new root + fmt.Println("LeftRotate: return=" + y.value + "; L:" + y.left.value + "; R:" + y.right.value) + fmt.Println("LeftRotate: leftNode=" + y.getLeftValue() + "; L:" + y.left.getLeftValue() + "; R:" + y.left.getRightValue()) + fmt.Println("LeftRotate: rightNode=" + y.getRightValue() + "; L:" + y.right.getLeftValue() + "; R:" + y.right.getRightValue()) + return y +} + +func GetBalance(n *Node) int { + if n == nil { + return 0 + } + return n.left.Len() - n.right.Len() +} + +func CompareNodes(n1, n2 *Node) int { + return strings.Compare(n1.value, n2.value) +}