exercism/go/custom-set/stringset.go

509 lines
11 KiB
Go

package stringset
import (
"errors"
"fmt"
"strconv"
"strings"
)
// testVersion records which revision of the exercise's test suite this
// implementation targets (presumably checked by that suite — confirm).
const testVersion = 3

// Set is a collection of unique strings that supports set operations.
// The values live in a binary search tree reached through top; a top node
// whose value is "" stands for the empty set.
type Set struct {
	top *Node
}

// Node is one entry of the tree behind a Set. Values that sort before
// value are stored under left; all others are stored under right.
type Node struct {
	value       string
	left, right *Node
}
// New returns an empty Set that is ready for use. The set owns a fresh
// top node whose empty value marks it as holding nothing.
func New() Set {
	return Set{top: &Node{}}
}
// NewFromSlice builds a Set that holds every string in s, adding them in
// slice order.
func NewFromSlice(s []string) Set {
	set := New()
	for _, v := range s {
		set.Add(v)
	}
	return set
}
// Add inserts v into the set. Adding a value that is already present is a
// no-op.
//
// Add has a value receiver, so it cannot change which node s.top points at:
// every copy of the Set shares that pointer. When rebalancing promotes a
// different node to the root, the old top's contents are moved into a fresh
// node and the new root's contents are copied into the shared top node. The
// shared pointer therefore stays valid and no node ends up referring to
// itself.
//
// NOTE: an empty top value marks an empty set, so adding "" to an empty set
// has no effect.
func (s Set) Add(v string) {
	if s.top.value == "" {
		s.top.value = v
		return
	}
	if s.Has(v) {
		return
	}
	root := s.top.Add(&Node{value: v})
	if root == s.top {
		return
	}
	// Rotations keep every node of the subtree, so the old top now hangs
	// somewhere below root. Swap the link to it for a fresh node holding the
	// same contents...
	demoted := &Node{value: s.top.value, left: s.top.left, right: s.top.right}
	for p := root; p != nil; {
		if p.left == s.top {
			p.left = demoted
			break
		}
		if p.right == s.top {
			p.right = demoted
			break
		}
		if s.top.value < p.value {
			p = p.left
		} else {
			p = p.right
		}
	}
	// ...then let the shared top node take over the new root's contents.
	*s.top = *root
}
// Delete removes v from the set. Deleting a value that is not present is a
// no-op.
//
// Delete has a value receiver, so it cannot change which node s.top points
// at: every copy of the Set shares that pointer. When the top node itself
// has to go, its replacement's contents are copied into it instead. The
// tree is not rebalanced after a removal.
func (s Set) Delete(v string) {
	if s.top == nil || s.top.value == "" {
		return // empty set
	}

	// Walk down to the node holding v, remembering its parent (nil for top).
	var parent *Node
	target := s.top
	for target != nil && target.value != v {
		parent = target
		if v < target.value {
			target = target.left
		} else {
			target = target.right
		}
	}
	if target == nil {
		return // v is not in the set
	}

	if target.left != nil && target.right != nil {
		// Two children: take over the in-order successor's value, then unlink
		// the successor. Being the leftmost node of the right subtree, it has
		// no left child.
		succParent, succ := target, target.right
		for succ.left != nil {
			succParent, succ = succ, succ.left
		}
		target.value = succ.value
		if succParent == target {
			succParent.right = succ.right
		} else {
			succParent.left = succ.right
		}
		return
	}

	// Zero or one child: splice that child (possibly nil) into target's place.
	child := target.left
	if child == nil {
		child = target.right
	}
	switch {
	case parent == nil && child == nil:
		target.value = "" // removed the last value; "" marks the set empty
	case parent == nil:
		*target = *child // target is the shared top node; keep the pointer
	case parent.left == target:
		parent.left = child
	default:
		parent.right = child
	}
}
// PrettyPrint writes a debugging view of the tree to standard output,
// framed above and below by a rule of '=' characters.
func (s Set) PrettyPrint() {
	const rule = "========="
	fmt.Println(rule)
	s.pp(s.top, 0)
	fmt.Println(rule)
}
// pp prints the subtree rooted at n, one value per line. Each level is
// indented four spaces deeper than its parent, and children are printed
// before their parent (left subtree, right subtree, then the node itself).
// A nil subtree prints nothing.
func (s Set) pp(n *Node, indent int) {
	// Check for nil before touching n's fields; the old order dereferenced
	// n first and panicked on a nil subtree.
	if n == nil {
		return
	}
	if n.left == n || n.right == n {
		// A node pointing at itself would make the recursion below endless.
		fmt.Println("Circular reference... You done messed up.")
		return
	}
	s.pp(n.left, indent+4)
	s.pp(n.right, indent+4)
	if indent < 0 {
		indent = 0 // strings.Repeat panics on a negative count
	}
	fmt.Println(strings.Repeat(" ", indent) + n.value)
}
// find returns the node holding v, or nil when v is absent or the set is
// empty.
func (s *Set) find(v string) *Node {
	if s.top.value == "" {
		return nil
	}
	// Node.find reports absence as a nil node plus an error; the nil node
	// is all this method needs.
	node, _ := s.top.find(v)
	return node
}
// findParent returns the parent of the node holding v.
// A result of (nil, nil) means v is held by the top node. An error is
// returned when the set is empty or v is not in the set.
func (s *Set) findParent(v string) (*Node, error) {
	// Check emptiness first. The old code tested it last, so that return was
	// unreachable, and findParent("") on an empty set claimed a top-node match.
	if s.top.value == "" {
		return nil, errors.New("Empty Set")
	}
	if s.top.value == v {
		// The top node has no parent.
		return nil, nil
	}
	return s.top.findParent(v)
}
// findHome attaches the subtree v to the set. In a non-empty set the work
// is delegated to the top node. In an empty set, v's contents are copied
// into the existing top node so that the shared top pointer is kept.
func (s *Set) findHome(v *Node) {
	if s.top.value != "" {
		s.top.findHome(v)
		return
	}
	*s.top = Node{value: v.value, left: v.left, right: v.right}
}
// Has reports whether v is a member of the set.
func (s *Set) Has(v string) bool {
	if node := s.find(v); node != nil {
		return true
	}
	return false
}
// IsEmpty reports whether the set holds no values. An empty top value
// marks the empty set. A Set with no top node at all (the zero Set, not
// built with New) is also reported as empty instead of causing a
// nil-pointer panic.
func (s *Set) IsEmpty() bool {
	return s.top == nil || s.top.value == ""
}
// Len reports how many values the set holds.
func (s *Set) Len() int {
	if s.IsEmpty() {
		return 0
	}
	return s.top.Len()
}
// Slice returns the set's values in the order the tree's in-order walk
// yields them. An empty set gives an empty, non-nil slice.
func (s Set) Slice() []string {
	if s.IsEmpty() {
		return []string{}
	}
	return s.top.Slice()
}
// String renders the set wrapped in braces, with the tree's own rendering
// (Node.String) inside. An empty set renders as {}.
func (s Set) String() string {
	inner := ""
	if s.top.value != "" {
		inner = s.top.String()
	}
	return "{" + inner + "}"
}
// find walks down from sv looking for val. It returns the node holding val,
// or a nil node together with a "Value not found" error when the branch does
// not contain it.
func (sv *Node) find(val string) (*Node, error) {
	cur := sv
	for cur.value != val {
		var next *Node
		if strings.Compare(val, cur.value) < 0 {
			next = cur.left
		} else {
			next = cur.right
		}
		if next == nil {
			return nil, errors.New("Value not found")
		}
		cur = next
	}
	return cur, nil
}
// findParent returns the node whose direct child holds val. A result of
// (nil, nil) means sv itself holds val. A "Value not found" error is
// returned when val is not in this branch.
func (sv *Node) findParent(val string) (*Node, error) {
	if sv.value == val {
		return nil, nil
	}
	for cur := sv; ; {
		var next *Node
		switch c := strings.Compare(val, cur.value); {
		case c < 0:
			next = cur.left
		case c > 0:
			next = cur.right
		}
		if next == nil {
			return nil, errors.New("Value not found")
		}
		if next.value == val {
			return cur, nil
		}
		cur = next
	}
}
// findHome hangs the whole subtree v below sv, at the position where
// v.value sorts. If v.value equals a node's value, v itself is dropped and
// its children are re-homed one at a time, left first, then right.
func (sv *Node) findHome(v *Node) {
	switch c := strings.Compare(v.value, sv.value); {
	case c < 0:
		if sv.left != nil {
			sv.left.findHome(v)
			return
		}
		sv.left = v
	case c > 0:
		if sv.right != nil {
			sv.right.findHome(v)
			return
		}
		sv.right = v
	default:
		if v.left != nil {
			sv.findHome(v.left)
		}
		if v.right != nil {
			sv.findHome(v.right)
		}
	}
}
// Add inserts n into the subtree rooted at sv. It returns the root of that
// subtree after rebalancing, which may be a different node from sv. Values
// that sort before sv.value go left; all others, including an equal value,
// go right, so callers should skip values that are already present.
//
// Rebalancing uses AVL-style rotations driven by GetBalance. A node whose
// balance is above 1 is rotated right; one below -1 is rotated left. When
// the heavy child leans the opposite way, it is rotated first (the
// left-right and right-left cases). The case is chosen from the child's own
// balance rather than from where n landed, which guarantees that the node
// being rotated up exists.
func (sv *Node) Add(n *Node) *Node {
	if strings.Compare(n.value, sv.value) < 0 {
		if sv.left == nil {
			sv.left = n
		} else {
			sv.left = sv.left.Add(n)
		}
	} else {
		if sv.right == nil {
			sv.right = n
		} else {
			sv.right = sv.right.Add(n)
		}
	}

	switch balance := GetBalance(sv); {
	case balance > 1: // left-heavy
		if GetBalance(sv.left) < 0 {
			sv.left = LeftRotate(sv.left)
		}
		return RightRotate(sv)
	case balance < -1: // right-heavy
		if GetBalance(sv.right) > 0 {
			sv.right = RightRotate(sv.right)
		}
		return LeftRotate(sv)
	}
	return sv
}

// getLeftValue returns the value of sv's left child, or "" if it has none.
func (sv *Node) getLeftValue() string {
	if sv.left == nil {
		return ""
	}
	return sv.left.value
}

// getRightValue returns the value of sv's right child, or "" if it has none.
func (sv *Node) getRightValue() string {
	if sv.right == nil {
		return ""
	}
	return sv.right.value
}

// Len returns how many values are stored in this branch. A nil branch
// holds none.
func (sv *Node) Len() int {
	if sv == nil {
		return 0
	}
	return 1 + sv.left.Len() + sv.right.Len()
}

// Has reports whether this branch contains the value v.
func (sv *Node) Has(v string) bool {
	found, _ := sv.find(v) // find reports absence with a nil node
	return found != nil
}

// String lists this branch's values in in-order order as Go-quoted strings
// separated by ", ", e.g. "a", "b". strconv.Quote escapes embedded quotes
// and control characters, so a value such as `a", "b` cannot be mistaken
// for two separate values. A child pointer that refers back to sv itself
// is shown as ** instead of recursing forever.
func (sv *Node) String() string {
	ret := ""
	if sv.left != nil {
		if sv.left == sv {
			ret += "**, "
		} else {
			ret += sv.left.String() + ", "
		}
	}
	ret += strconv.Quote(sv.value)
	if sv.right != nil {
		if sv.right == sv {
			ret += ", **"
		} else {
			ret += ", " + sv.right.String()
		}
	}
	return ret
}
// Slice returns every value in this branch, collected by an in-order walk
// (left subtree, node, right subtree).
func (sv *Node) Slice() []string {
	var out []string
	var walk func(n *Node)
	walk = func(n *Node) {
		if n.left != nil {
			walk(n.left)
		}
		out = append(out, n.value)
		if n.right != nil {
			walk(n.right)
		}
	}
	walk(sv)
	return out
}
// Equal reports whether s1 and s2 contain exactly the same values.
//
// Both sets are compared through Slice, which lists values in tree
// in-order order. Both trees use the same ordering, so equal sets yield
// identical slices. Values are compared directly rather than through
// String() output: without escaping, a value containing `", "` could make
// two different sets render identically.
func Equal(s1, s2 Set) bool {
	a, b := s1.Slice(), s2.Slice()
	if len(a) != len(b) {
		return false
	}
	for i := range a {
		if a[i] != b[i] {
			return false
		}
	}
	return true
}
// Subset reports whether every value of s1 also belongs to s2.
func Subset(s1, s2 Set) bool {
	for _, v := range s1.Slice() {
		if !s2.Has(v) {
			return false
		}
	}
	return true
}
// Disjoint reports whether s1 and s2 share no values.
func Disjoint(s1, s2 Set) bool {
	for _, v := range s1.Slice() {
		if s2.Has(v) {
			return false
		}
	}
	return true
}
// Intersection returns a new Set of the values found in both s1 and s2.
func Intersection(s1, s2 Set) Set {
	var common []string
	for _, v := range s1.Slice() {
		if s2.Has(v) {
			common = append(common, v)
		}
	}
	return NewFromSlice(common)
}
// Union returns a new Set holding every value from s1 and s2. Values from
// s1 are added first, then those from s2.
func Union(s1, s2 Set) Set {
	first, second := s1.Slice(), s2.Slice()
	combined := make([]string, 0, len(first)+len(second))
	combined = append(combined, first...)
	combined = append(combined, second...)
	return NewFromSlice(combined)
}
// Difference returns a new Set of the values in s1 that are not in s2.
func Difference(s1, s2 Set) Set {
	var onlyFirst []string
	for _, v := range s1.Slice() {
		if s2.Has(v) {
			continue
		}
		onlyFirst = append(onlyFirst, v)
	}
	return NewFromSlice(onlyFirst)
}
// SymmetricDifference returns a new Set of the values that belong to
// exactly one of s1 and s2.
func SymmetricDifference(s1, s2 Set) Set {
	onlyFirst := Difference(s1, s2)
	onlySecond := Difference(s2, s1)
	return Union(onlyFirst, onlySecond)
}
/* Helper Functions for Balancing the Tree */

// RightRotate rotates the subtree rooted at y to the right and returns its
// new root, y's former left child x:
//
//	      y            x
//	     / \          / \
//	    x   c   =>   a   y
//	   / \              / \
//	  a   t2           t2  c
//
// In-order order is preserved. If y or y.left is nil there is nothing to
// rotate, and y is returned unchanged.
func RightRotate(y *Node) *Node {
	if y == nil || y.left == nil {
		return y
	}
	x := y.left
	t2 := x.right
	x.right = y
	y.left = t2
	return x
}
// LeftRotate rotates the subtree rooted at x to the left and returns its
// new root, x's former right child y:
//
//	    x                y
//	   / \              / \
//	  a   y     =>     x   c
//	     / \          / \
//	    t2  c        a   t2
//
// In-order order is preserved. If x or x.right is nil there is nothing to
// rotate, and x is returned unchanged.
func LeftRotate(x *Node) *Node {
	if x == nil || x.right == nil {
		return x
	}
	y := x.right
	t2 := y.left
	y.left = x
	x.right = t2
	return y
}
// GetBalance returns the balance factor of n: the height of its left
// subtree minus the height of its right subtree, or 0 for a nil node.
// Node.Add rotates a node whose balance factor falls outside [-1, 1].
//
// Heights are used rather than node counts. A height-balanced subtree can
// differ in size by more than one, and the AVL rotations built on this
// value assume heights.
func GetBalance(n *Node) int {
	if n == nil {
		return 0
	}
	return nodeHeight(n.left) - nodeHeight(n.right)
}

// nodeHeight returns the number of nodes on the longest downward path from
// n, or 0 for nil.
func nodeHeight(n *Node) int {
	if n == nil {
		return 0
	}
	l, r := nodeHeight(n.left), nodeHeight(n.right)
	if l > r {
		return l + 1
	}
	return r + 1
}
// CompareNodes orders two nodes by value. It returns -1, 0 or +1, the same
// convention as strings.Compare.
func CompareNodes(n1, n2 *Node) int {
	a, b := n1.value, n2.value
	switch {
	case a == b:
		return 0
	case a < b:
		return -1
	default:
		return 1
	}
}