From 6b01a7f10c7f7aa62b6cafd601949d21630fdce4 Mon Sep 17 00:00:00 2001 From: Matt Goodall Date: Mon, 9 Jul 2012 14:50:35 +0100 Subject: [PATCH] Cleanly detect & signal end of stream, and fix code above. --- client.go | 5 ++++- component.go | 5 ++++- src/xmpp/client.go | 21 +++++++++++---------- src/xmpp/component.go | 3 +++ src/xmpp/stream.go | 17 +++++++++++++++-- src/xmpp/xmpp.go | 15 ++++++++++++--- 6 files changed, 49 insertions(+), 17 deletions(-) diff --git a/client.go b/client.go index 4676250..6a1a00a 100644 --- a/client.go +++ b/client.go @@ -52,7 +52,10 @@ func main() { // Log any stanzas that are not handled elsewhere. go func() { for { - stanza := x.Recv() + stanza, err := x.Recv() + if err != nil { + log.Fatal(err) + } log.Printf("* recv: %v\n", stanza) } }() diff --git a/component.go b/component.go index fd488ba..12d77fe 100644 --- a/component.go +++ b/component.go @@ -31,7 +31,10 @@ func main() { } for { - v := x.Recv() + v, err := x.Recv() + if err != nil { + log.Fatal(err) + } log.Printf("recv: %v", v) } } diff --git a/src/xmpp/client.go b/src/xmpp/client.go index 4465425..09862d3 100644 --- a/src/xmpp/client.go +++ b/src/xmpp/client.go @@ -137,14 +137,19 @@ func authenticatePlain(stream *Stream, user, password string) error { return err } - if se, err := stream.Next(nil); err != nil { + se, err := stream.Next(nil) + if err != nil { return err - } else { - if se.Name.Local == "failure" { - f := new(saslFailure) - stream.DecodeElement(f, se) - return errors.New(fmt.Sprintf("Authentication failed: %s", f.Reason.Local)) + } + switch se.Name.Local { + case "success": + if err := stream.Skip(); err != nil { + return err } + case "failure": + f := new(saslFailure) + stream.DecodeElement(f, se) + return errors.New(fmt.Sprintf("Authentication failed: %s", f.Reason.Local)) } return nil @@ -219,10 +224,6 @@ type tlsStartTLS struct { type required struct {} -type saslSuccess struct { - XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl success"` -} - type saslFailure struct { XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl failure"` Reason xml.Name `xml:",any"` diff --git a/src/xmpp/component.go b/src/xmpp/component.go index 246313a..bf0d1c1 100644 --- a/src/xmpp/component.go +++ b/src/xmpp/component.go @@ -72,6 +72,9 @@ func handshake(stream *Stream, streamId, secret string) error { if _, err := stream.Next(&xml.Name{"jabber:component:accept", "handshake"}); err != nil { return err } + if err := stream.Skip(); err != nil { + return err + } return nil } diff --git a/src/xmpp/stream.go b/src/xmpp/stream.go index a1f597f..db864af 100644 --- a/src/xmpp/stream.go +++ b/src/xmpp/stream.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "encoding/xml" "fmt" + "io" "log" "net" ) @@ -96,23 +97,35 @@ func (stream *Stream) send(b []byte) error { } // Find start of next stanza. If match is not nil the stanza's XML name -// compared and must be equal. +// is compared and must be equal. +// Bad things are very likely to happen if a call to Next() is successful but +// you don't actually decode or skip the element. func (stream *Stream) Next(match *xml.Name) (*xml.StartElement, error) { for { t, err := stream.dec.Token() if err != nil { return nil, err } - if e, ok := t.(xml.StartElement); ok { + switch e := t.(type) { + case xml.StartElement: if match != nil && e.Name != *match { return nil, fmt.Errorf("Expected %s, got %s", *match, e.Name) } return &e, nil + case xml.EndElement: + log.Printf("EOF due to %s\n", e.Name) + return nil, io.EOF } } panic("Unreachable") } +// Skip reads tokens until it has reaches the end element of the most recent +// start element that has already been read. +func (stream *Stream) Skip() error { + return stream.dec.Skip() +} + // Decode the next stanza. Works like xml.Unmarshal but reads from the stream's // connection. func (stream *Stream) Decode(v interface{}) error { diff --git a/src/xmpp/xmpp.go b/src/xmpp/xmpp.go index 08cffd3..05cba1d 100644 --- a/src/xmpp/xmpp.go +++ b/src/xmpp/xmpp.go @@ -37,8 +37,13 @@ func (x *XMPP) Send(v interface{}) { x.out <- v } -func (x *XMPP) Recv() interface{} { - return <-x.in +// Return the next stanza. +func (x *XMPP) Recv() (interface{}, error) { + v := <-x.in + if e, ok := v.(error); ok { + return nil, e + } + return v, nil } func (x *XMPP) SendRecv(iq *Iq) (*Iq, error) { @@ -107,10 +112,14 @@ func (x *XMPP) sender() { } func (x *XMPP) receiver() { + + defer close(x.in) + for { start, err := x.stream.Next(nil) if err != nil { - log.Fatal(err) + x.in <- err + return } var v interface{}