Looks like it's pretty much working!

Yeah, I'm pretty awesome.
Brian Buller 2015-08-14 17:21:18 -05:00
parent 9afaeee97d
commit 14d105740a
1 changed file with 91 additions and 38 deletions


@@ -9,36 +9,56 @@ import (
 	"os"
 	"strconv"
 	"strings"
+	"time"
 )
 
 const PROGRAM_NAME = "fullscrape"
 
 func main() {
-	if len(os.Args) <= 3 {
-		fmt.Print("Usage: " + PROGRAM_NAME + " <url> <output-directory> <depth> [-nx]\n")
-		fmt.Print(" -n: Don't rewrite urls in source files to work locally\n")
-		fmt.Print(" -x: Cross domains when following links\n")
-		os.Exit(1)
-	}
-	fix_urls := true
-	cross_domains := false
 	req_url := os.Args[1] //"http://golang.org/"
 	out_dir := os.Args[2]
-	req_depth, err := strconv.Atoi(os.Args[3]) //4
-	if err != nil {
-		fmt.Print("Invalid Depth specified. Please give a number.\n")
-		fmt.Print("Usage: " + PROGRAM_NAME + " <url> <output-directory> <depth> [-n]\n")
-		os.Exit(1)
-	}
+	depthFlag := -1
+	norewriteFlag := false
+	crossdomainFlag := false
+	throttleFlag := 1000
+	var err error
 	if len(os.Args) > 3 {
-		tst_arg := os.Args[4]
-		if strings.Index(tst_arg, "n") != -1 {
-			fix_urls = false
-		} else if strings.Index(tst_arg, "x") != -1 {
-			cross_domains = true
+		tst := os.Args[3]
+		depthArg := strings.IndexRune(tst, 'd')
+		if depthArg >= 0 {
+			// The actual depth value should either be depthArg+1
+			// or, if that is '=', depthArg+2
+			if tst[depthArg+1] == '=' {
+				depthFlag, err = strconv.Atoi(strings.Split(tst, "")[depthArg+2])
+			} else {
+				depthFlag, err = strconv.Atoi(strings.Split(tst, "")[depthArg+1])
+			}
+			if err != nil {
+				fmt.Printf("Invalid depth given (must be an integer): %s\n", depthFlag)
+				os.Exit(1)
+			}
+		}
+		norewriteFlag = (strings.IndexRune(tst, 'n') >= 0)
+		crossdomainFlag = (strings.IndexRune(tst, 'x') >= 0)
+		throttleArg := strings.IndexRune(tst, 't')
+		if throttleArg >= 0 {
+			// The actual throttle value should either be throttleArg+1...
+			// or, if that is '=', throttleArg+2...
+			if tst[depthArg+1] == '=' {
+				// The throttle argument MUST have a space after it
+				throttleFlag, err = strconv.Atoi(strings.Split(tst, "")[throttleArg+2])
+			} else {
+				throttleFlag, err = strconv.Atoi(strings.Split(tst, "")[throttleArg+1])
+			}
+			if err != nil {
+				fmt.Printf("Invalid depth given (must be milliseconds as an integer): %s\n", depthFlag)
+				os.Exit(1)
+			}
 		}
 	}
-	if err = CreateDirIfNotExist(out_dir); err != nil {
+	if err := CreateDirIfNotExist(out_dir); err != nil {
 		fmt.Print("Unable to create initial directory %s\n", out_dir)
 		fmt.Print("Error: %s\n", err)
 		os.Exit(1)
@@ -54,10 +74,12 @@ func main() {
 	}
 	c.rootUrl = req_url
 	c.outDir = out_dir
-	c.fixUrls = fix_urls
-	c.xDomain = cross_domains
+	c.fixUrls = norewriteFlag
+	c.xDomain = crossdomainFlag
+	c.depth = depthFlag
+	c.throttle = time.Duration(throttleFlag)
 
-	c.Crawl(req_url, req_depth)
+	c.Crawl()
 }
 
 type unprocessed struct {
@@ -66,28 +88,48 @@ type unprocessed struct {
 }
 
 type Crawler struct {
 	rootUrl string
 	outDir string
 	fixUrls bool
 	xDomain bool
+	depth int
+	throttle time.Duration
 }
 
-func (c *Crawler) Crawl(url string, depth int) {
+func (c *Crawler) Crawl() {
+	if c.depth >= 0 {
+		fmt.Printf("Processing %s with depth %d (Norewrite: %t, XDomain: %t, Throttle: %d)\n", c.rootUrl, c.depth, c.fixUrls, c.xDomain, c.throttle)
+	} else {
+		fmt.Printf("Processing %s (Norewrite: %t, XDomain: %t, Throttle: %d)\n", c.rootUrl, c.fixUrls, c.xDomain, c.throttle)
+	}
 	// Setup channel for inputs to be processed
 	up := make(chan unprocessed, 0)
 	// Kick off processing and count how many pages are left to process
-	go c.getPage(url, depth, up)
+	go c.getPage(c.rootUrl, c.depth, up)
 	outstanding := 1
 	visited := make(map[string]bool)
+	status := fmt.Sprintf("Files %d/%d", len(visited), outstanding+len(visited))
 	for outstanding > 0 {
+		done := len(visited) - outstanding
+		if done < 0 {
+			done = 0
+		}
+		fmt.Print(strings.Repeat("", len(status)))
+		status = fmt.Sprintf("Files %d/%d", done, len(visited))
+		fmt.Print(status)
+		if c.throttle > 0 {
+			time.Sleep(time.Millisecond * c.throttle)
+		}
 		// Pop a visit from the channel
 		next := <-up
 		outstanding--
 		// If we're too deep, skip it
-		if next.depth <= 0 {
+		if next.depth == 0 {
 			continue
 		}
@@ -101,20 +143,23 @@ func (c *Crawler) Crawl(url string, depth int) {
 			// All good to visit them
 			outstanding++
 			visited[link] = true
-			go c.getPage(link, depth, up)
+			go c.getPage(link, next.depth, up)
 		}
 	}
+	fmt.Print(strings.Repeat("", len(status)))
+	status = fmt.Sprintf("Files %d/%d", len(visited), len(visited))
+	fmt.Printf("%s\n", status)
 }
 
 func (c *Crawler) getPage(url string, depth int, r chan unprocessed) {
 	_, urls, err := c.Fetch(url)
 	//body, urls, err := c.Fetch(url)
-	fmt.Printf("Found: %s\n", url)
+	//fmt.Printf("Found: %s\n", url)
 	if err != nil {
 		fmt.Println(err)
 	}
-	fmt.Printf("Pulled URLS: %s\n", urls)
+	//fmt.Printf("Pulled URLS: %s\n", urls)
 	r <- unprocessed{depth - 1, urls}
 }
@@ -123,7 +168,7 @@ func (c *Crawler) Fetch(url string) (string, []string, error) {
 	urls := make([]string, 0)
 	// Ok, go get URL
 	response, err := http.Get(url)
-	if err != nil {
+	if err != nil || response.StatusCode != 200 {
 		return "", nil, err
 	}
 	body, err := ioutil.ReadAll(response.Body)
@@ -165,9 +210,9 @@ func (c *Crawler) Fetch(url string) (string, []string, error) {
 	for {
 		tt := z.Next()
 		switch {
-		case tt == html.StartTagToken:
+		case tt == html.StartTagToken || tt == html.SelfClosingTagToken:
 			t := z.Token()
-			if t.Data == "a" || t.Data == "link" {
+			if t.Data == "link" || t.Data == "a" {
 				for _, a := range t.Attr {
 					if a.Key == "href" {
 						if c.CheckUrl(a.Val) {
@@ -176,7 +221,7 @@ func (c *Crawler) Fetch(url string) (string, []string, error) {
 						break
 					}
 				}
-			} else if t.Data == "img" {
+			} else if t.Data == "img" || t.Data == "script" {
 				for _, a := range t.Attr {
 					if a.Key == "src" {
 						if c.CheckUrl(a.Val) {
@@ -203,6 +248,14 @@ func (c *Crawler) Fetch(url string) (string, []string, error) {
  * The main purpose is for cross-domain checks
  */
 func (c *Crawler) CheckUrl(url string) bool {
+	// Ignore anchor urls
+	if strings.IndexRune(url, '#') >= 0 {
+		return false
+	}
+	// Ignore "mailto" links
+	if strings.HasPrefix(url, "mailto:") {
+		return false
+	}
 	if !c.xDomain {
 		if strings.HasPrefix(url, "http") {
 			return strings.HasPrefix(url, c.rootUrl)
@@ -234,6 +287,6 @@ func CreateDirIfNotExist(dir string) error {
 
 func WriteFile(d string, filename string) error {
 	do := []byte(d)
-	fmt.Printf("Writing %s\n", filename)
+	//fmt.Printf("Writing %s\n", filename)
 	return ioutil.WriteFile(filename, do, 0664)
 }
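
For context, the new main() replaces the old positional <depth> argument with a single combined flag string in os.Args[3], scanned with strings.IndexRune: 'd' carries the crawl depth, 't' a throttle value in milliseconds (each optionally followed by '='), and 'n'/'x' toggle the no-rewrite and cross-domain behaviour. Below is a rough standalone sketch of that parsing approach. The numberAfter helper and the multi-digit value handling are illustrative assumptions, not the commit's code, which only reads the single character after the flag rune.

// Standalone sketch (not the commit's exact code): parse a combined flag
// string such as "d=3xt=500" into the same flag variables main() uses.
package main

import (
	"fmt"
	"strconv"
	"strings"
	"unicode"
)

// numberAfter returns the integer that follows the flag rune at index i,
// skipping an optional '='. Illustrative helper, not part of the commit.
func numberAfter(s string, i int) (int, error) {
	j := i + 1
	if j < len(s) && s[j] == '=' {
		j++
	}
	k := j
	for k < len(s) && unicode.IsDigit(rune(s[k])) {
		k++
	}
	return strconv.Atoi(s[j:k])
}

func main() {
	tst := "d=3xt=500" // hypothetical combined flag argument (os.Args[3] in the commit)

	depthFlag, throttleFlag := -1, 1000
	norewriteFlag := strings.IndexRune(tst, 'n') >= 0
	crossdomainFlag := strings.IndexRune(tst, 'x') >= 0

	if i := strings.IndexRune(tst, 'd'); i >= 0 {
		if v, err := numberAfter(tst, i); err == nil {
			depthFlag = v
		}
	}
	if i := strings.IndexRune(tst, 't'); i >= 0 {
		if v, err := numberAfter(tst, i); err == nil {
			throttleFlag = v
		}
	}
	fmt.Println(depthFlag, norewriteFlag, crossdomainFlag, throttleFlag) // 3 false true 500
}

Under that reading, an invocation would look something like: fullscrape http://golang.org/ ./out d=3xt=500, with the URL and output directory positional and the flag string optional.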