1
0
Fork 0

Cleanly detect & signal end of stream, and fix code above.

This commit is contained in:
Matt Goodall 2012-07-09 14:50:35 +01:00
parent 5536f034bd
commit 6b01a7f10c
6 changed files with 49 additions and 17 deletions

View File

@ -52,7 +52,10 @@ func main() {
// Log any stanzas that are not handled elsewhere. // Log any stanzas that are not handled elsewhere.
go func() { go func() {
for { for {
stanza := x.Recv() stanza, err := x.Recv()
if err != nil {
log.Fatal(err)
}
log.Printf("* recv: %v\n", stanza) log.Printf("* recv: %v\n", stanza)
} }
}() }()

View File

@ -31,7 +31,10 @@ func main() {
} }
for { for {
v := x.Recv() v, err := x.Recv()
if err != nil {
log.Fatal(err)
}
log.Printf("recv: %v", v) log.Printf("recv: %v", v)
} }
} }

View File

@ -137,14 +137,19 @@ func authenticatePlain(stream *Stream, user, password string) error {
return err return err
} }
if se, err := stream.Next(nil); err != nil { se, err := stream.Next(nil)
if err != nil {
return err return err
} else { }
if se.Name.Local == "failure" { switch se.Name.Local {
f := new(saslFailure) case "success":
stream.DecodeElement(f, se) if err := stream.Skip(); err != nil {
return errors.New(fmt.Sprintf("Authentication failed: %s", f.Reason.Local)) return err
} }
case "failure":
f := new(saslFailure)
stream.DecodeElement(f, se)
return errors.New(fmt.Sprintf("Authentication failed: %s", f.Reason.Local))
} }
return nil return nil
@ -219,10 +224,6 @@ type tlsStartTLS struct {
type required struct {} type required struct {}
type saslSuccess struct {
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl success"`
}
type saslFailure struct { type saslFailure struct {
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl failure"` XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl failure"`
Reason xml.Name `xml:",any"` Reason xml.Name `xml:",any"`

View File

@ -72,6 +72,9 @@ func handshake(stream *Stream, streamId, secret string) error {
if _, err := stream.Next(&xml.Name{"jabber:component:accept", "handshake"}); err != nil { if _, err := stream.Next(&xml.Name{"jabber:component:accept", "handshake"}); err != nil {
return err return err
} }
if err := stream.Skip(); err != nil {
return err
}
return nil return nil
} }

View File

@ -5,6 +5,7 @@ import (
"crypto/tls" "crypto/tls"
"encoding/xml" "encoding/xml"
"fmt" "fmt"
"io"
"log" "log"
"net" "net"
) )
@ -96,23 +97,35 @@ func (stream *Stream) send(b []byte) error {
} }
// Find start of next stanza. If match is not nil the stanza's XML name // Find start of next stanza. If match is not nil the stanza's XML name
// compared and must be equal. // is compared and must be equal.
// Bad things are very likely to happen if a call to Next() is successful but
// you don't actually decode or skip the element.
func (stream *Stream) Next(match *xml.Name) (*xml.StartElement, error) { func (stream *Stream) Next(match *xml.Name) (*xml.StartElement, error) {
for { for {
t, err := stream.dec.Token() t, err := stream.dec.Token()
if err != nil { if err != nil {
return nil, err return nil, err
} }
if e, ok := t.(xml.StartElement); ok { switch e := t.(type) {
case xml.StartElement:
if match != nil && e.Name != *match { if match != nil && e.Name != *match {
return nil, fmt.Errorf("Expected %s, got %s", *match, e.Name) return nil, fmt.Errorf("Expected %s, got %s", *match, e.Name)
} }
return &e, nil return &e, nil
case xml.EndElement:
log.Printf("EOF due to %s\n", e.Name)
return nil, io.EOF
} }
} }
panic("Unreachable") panic("Unreachable")
} }
// Skip reads tokens until it has reaches the end element of the most recent
// start element that has already been read.
func (stream *Stream) Skip() error {
return stream.dec.Skip()
}
// Decode the next stanza. Works like xml.Unmarshal but reads from the stream's // Decode the next stanza. Works like xml.Unmarshal but reads from the stream's
// connection. // connection.
func (stream *Stream) Decode(v interface{}) error { func (stream *Stream) Decode(v interface{}) error {

View File

@ -37,8 +37,13 @@ func (x *XMPP) Send(v interface{}) {
x.out <- v x.out <- v
} }
func (x *XMPP) Recv() interface{} { // Return the next stanza.
return <-x.in func (x *XMPP) Recv() (interface{}, error) {
v := <-x.in
if e, ok := v.(error); ok {
return nil, e
}
return v, nil
} }
func (x *XMPP) SendRecv(iq *Iq) (*Iq, error) { func (x *XMPP) SendRecv(iq *Iq) (*Iq, error) {
@ -107,10 +112,14 @@ func (x *XMPP) sender() {
} }
func (x *XMPP) receiver() { func (x *XMPP) receiver() {
defer close(x.in)
for { for {
start, err := x.stream.Next(nil) start, err := x.stream.Next(nil)
if err != nil { if err != nil {
log.Fatal(err) x.in <- err
return
} }
var v interface{} var v interface{}