Cleanly detect & signal end of stream, and fix code above.

This commit is contained in:
Matt Goodall 2012-07-09 14:50:35 +01:00
parent 5536f034bd
commit 6b01a7f10c
6 changed files with 49 additions and 17 deletions

View File

@ -52,7 +52,10 @@ func main() {
// Log any stanzas that are not handled elsewhere.
go func() {
for {
stanza := x.Recv()
stanza, err := x.Recv()
if err != nil {
log.Fatal(err)
}
log.Printf("* recv: %v\n", stanza)
}
}()

View File

@ -31,7 +31,10 @@ func main() {
}
for {
v := x.Recv()
v, err := x.Recv()
if err != nil {
log.Fatal(err)
}
log.Printf("recv: %v", v)
}
}

View File

@ -137,14 +137,19 @@ func authenticatePlain(stream *Stream, user, password string) error {
return err
}
if se, err := stream.Next(nil); err != nil {
se, err := stream.Next(nil)
if err != nil {
return err
} else {
if se.Name.Local == "failure" {
f := new(saslFailure)
stream.DecodeElement(f, se)
return errors.New(fmt.Sprintf("Authentication failed: %s", f.Reason.Local))
}
switch se.Name.Local {
case "success":
if err := stream.Skip(); err != nil {
return err
}
case "failure":
f := new(saslFailure)
stream.DecodeElement(f, se)
return errors.New(fmt.Sprintf("Authentication failed: %s", f.Reason.Local))
}
return nil
@ -219,10 +224,6 @@ type tlsStartTLS struct {
type required struct {}
type saslSuccess struct {
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl success"`
}
type saslFailure struct {
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl failure"`
Reason xml.Name `xml:",any"`

View File

@ -72,6 +72,9 @@ func handshake(stream *Stream, streamId, secret string) error {
if _, err := stream.Next(&xml.Name{"jabber:component:accept", "handshake"}); err != nil {
return err
}
if err := stream.Skip(); err != nil {
return err
}
return nil
}

View File

@ -5,6 +5,7 @@ import (
"crypto/tls"
"encoding/xml"
"fmt"
"io"
"log"
"net"
)
@ -96,23 +97,35 @@ func (stream *Stream) send(b []byte) error {
}
// Find start of next stanza. If match is not nil the stanza's XML name
// compared and must be equal.
// is compared and must be equal.
// Bad things are very likely to happen if a call to Next() is successful but
// you don't actually decode or skip the element.
func (stream *Stream) Next(match *xml.Name) (*xml.StartElement, error) {
for {
t, err := stream.dec.Token()
if err != nil {
return nil, err
}
if e, ok := t.(xml.StartElement); ok {
switch e := t.(type) {
case xml.StartElement:
if match != nil && e.Name != *match {
return nil, fmt.Errorf("Expected %s, got %s", *match, e.Name)
}
return &e, nil
case xml.EndElement:
log.Printf("EOF due to %s\n", e.Name)
return nil, io.EOF
}
}
panic("Unreachable")
}
// Skip reads tokens until it has reaches the end element of the most recent
// start element that has already been read.
func (stream *Stream) Skip() error {
return stream.dec.Skip()
}
// Decode the next stanza. Works like xml.Unmarshal but reads from the stream's
// connection.
func (stream *Stream) Decode(v interface{}) error {

View File

@ -37,8 +37,13 @@ func (x *XMPP) Send(v interface{}) {
x.out <- v
}
func (x *XMPP) Recv() interface{} {
return <-x.in
// Return the next stanza.
func (x *XMPP) Recv() (interface{}, error) {
v := <-x.in
if e, ok := v.(error); ok {
return nil, e
}
return v, nil
}
func (x *XMPP) SendRecv(iq *Iq) (*Iq, error) {
@ -107,10 +112,14 @@ func (x *XMPP) sender() {
}
func (x *XMPP) receiver() {
defer close(x.in)
for {
start, err := x.stream.Next(nil)
if err != nil {
log.Fatal(err)
x.in <- err
return
}
var v interface{}